In [1]:
import os

# NOTE(review): this only restricts GPU visibility for this notebook's kernel.
# Nothing in the visible cells runs on a GPU, and the generated bash script does
# not export CUDA_VISIBLE_DEVICES, so a training run launched from a separate
# shell will not inherit it — set it there (or presumably via the accelerate
# config) if GPU 1 is intended for training.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


from pathlib import Path
from dataclasses import dataclass

In [2]:
# Absolute locations of the dataset, this repository and the local diffusers checkout.
dataset_dir = Path("/mnt/ssd2/xin/repo/DART/Liebherr_Product")
repo_dir = Path("/mnt/ssd2/xin/repo/DART/diversification")
diffusers_dir = Path("/mnt/ssd2/xin/repo/diffusers")

# Tag used to namespace per-model class data, scripts and outputs.
MODEL_NAME = "sd1-5"

# Derived directories; the script directory is created eagerly so scripts can be saved.
instance_data_dir = repo_dir.joinpath("instance_data")
class_data_dir = repo_dir.joinpath("dreambooth", "class_data", MODEL_NAME)
script_dir = repo_dir.joinpath("dreambooth", "scripts", MODEL_NAME)
script_dir.mkdir(exist_ok=True, parents=True)

## Generate the DreamBooth training script

In [3]:
def get_prompt(obj, placeholder_token, plural=False, naive=False):
    """Build a text prompt describing ``obj`` tagged with ``placeholder_token``.

    Args:
        obj: Object class name, e.g. "articulated dump truck".
        placeholder_token: Token identifying the specific instance, e.g. "<TA230>".
        plural: Describe one or several objects (ignored when ``naive``).
        naive: Return only the short one-sentence prompt.

    Returns:
        The prompt string.
    """
    tagged = f"{placeholder_token} {obj}"
    setting = "on a construction site."
    if naive:
        return f"A photo of {tagged}(s) {setting}"

    detail = (
        "may be partially visible, at a distance, or obscured, ensuring a variety of "
        "training examples for object detection. The background is complex, providing "
        "a realistic context."
    )
    if not plural:
        return (
            f"A photo of a {tagged} {setting} The image is high quality and "
            f"photorealistic. The {tagged} {detail}"
        )
    return (
        f"A photo of {tagged}(s) {setting} The image is high quality and "
        f"photorealistic, with one or several {tagged}s visible from various angles "
        f"and distances. The {tagged}s {detail}"
    )


def get_unique_path(folder, base_name):
    """Return ``folder/base_name``, or the first free ``folder/<stem>_<n><suffix>``.

    Candidates ``<stem>_1``, ``<stem>_2``, ... are tried in order until one does not
    exist. Nothing is created on disk.
    """
    candidate = Path(folder) / base_name
    if not candidate.exists():
        return candidate

    parent, stem, suffix = candidate.parent, candidate.stem, candidate.suffix
    n = 1
    while True:
        candidate = parent / f"{stem}_{n}{suffix}"
        if not candidate.exists():
            return candidate
        n += 1


def process_variables(
    obj,
    instance,
    repo_dir,
    hyperparameters,
    placeholder_token="sks",
    acc_config="default_config",
    report_to="wandb",
    steps_per_image=120,
    min_train_steps=800,
    image_suffixes=(".jpg", ".jpeg", ".png", ".webp", ".bmp"),
):
    """Resolve paths, prompts and settings needed to render a DreamBooth launch script.

    Args:
        obj: Object class name, e.g. "articulated dump truck". Spaces become
            underscores when used in paths.
        instance: Instance name. A name starting with a lowercase letter is treated
            as a semantic subclass name; one starting with an uppercase letter or a
            digit is treated as a model name (e.g. "TA230").
        repo_dir: Root of the diversification repo. Instance data, class data,
            script and output locations are all derived from it.
        hyperparameters: Object whose ``__dict__`` holds training hyperparameters
            (e.g. a ``Hyperparameters`` dataclass); merged last into the result.
        placeholder_token: Token identifying the instance in prompts.
        acc_config: Accelerate config name (without ``.yaml``) under
            ``~/.cache/huggingface/accelerate``.
        report_to: Logging backend forwarded to the training script.
        steps_per_image: Training steps per instance image.
        min_train_steps: Lower bound for ``max_train_steps``.
        image_suffixes: File suffixes (case-insensitive) counted as instance images.

    Returns:
        Dict of template variables consumed by ``parse_variables``.

    Raises:
        ValueError: If ``instance`` is empty or does not start with a letter or digit.
    """
    if not instance:
        raise ValueError("Instance name must be a non-empty string")

    obj_ = obj.replace(" ", "_")
    repo_dir = Path(repo_dir)
    dreambooth_dir = repo_dir / "dreambooth"
    model_name = "runwayml/stable-diffusion-v1-5"

    # Every location is derived from `repo_dir` so the argument is honoured
    # consistently (data dirs as well as scripts/outputs).
    obj_instance_data_dir = repo_dir / "instance_data" / obj_ / instance
    obj_class_data_root = dreambooth_dir / "class_data" / MODEL_NAME / obj_
    script_path = get_unique_path(
        dreambooth_dir / "scripts" / MODEL_NAME / obj_, instance + ".sh"
    )
    # Mirror the script location under "output". Built structurally instead of by
    # string replacement so "scripts"/".sh" elsewhere in the path is left untouched.
    output_dir = dreambooth_dir / "output" / MODEL_NAME / obj_ / Path(script_path).stem

    # Decide whether class images are generated per subclass or shared per object.
    first_char = instance[0]
    if first_char.islower():
        # Lowercase names are semantic subclass names: class prompt and class data
        # directory are specific to that subclass.
        class_prompt = f"a photo of a {instance.replace('_',' ')} {obj}"
        obj_class_data_dir = obj_class_data_root / instance
    elif first_char.isupper() or first_char.isdigit():
        # Model names (capital letter or digit first, e.g. "TA230", "230TA") share
        # the object-level class images.
        class_prompt = f"a photo of {obj}"
        obj_class_data_dir = obj_class_data_root
    else:
        raise ValueError(
            "Instance name should either be a model name (i.e. start with a capital letter or a digit) or a subclass name (i.e. start with a lowercase letter)"
        )

    # max_train_steps = max(#images * steps_per_image, min_train_steps). Count every
    # image file regardless of suffix case, not only lowercase "*.jpg".
    suffixes = {suffix.lower() for suffix in image_suffixes}
    num_instance_images = sum(
        1
        for path in obj_instance_data_dir.glob("*")
        if path.is_file() and path.suffix.lower() in suffixes
    )
    max_train_steps = max(num_instance_images * steps_per_image, min_train_steps)

    return {
        "max_train_steps": max_train_steps,
        "model_name": model_name,
        "instance_data_dir": str(obj_instance_data_dir),
        "output_dir": str(output_dir),
        "script_path": str(script_path),
        "class_data_dir": str(obj_class_data_dir),
        "instance_prompt": f"a photo of a {placeholder_token} {obj}",
        "class_prompt": class_prompt,
        "validation_prompt": get_prompt(
            obj, placeholder_token, plural=False, naive=False
        ),
        "acc_config_path": Path.home()
        / ".cache/huggingface/accelerate"
        / (str(acc_config) + ".yaml"),
        "report_to": report_to,
        **hyperparameters.__dict__,
    }


def parse_variables(variables):
    """Render the bash script that launches ``train_dreambooth.py`` via accelerate.

    Args:
        variables: Dict as produced by ``process_variables`` (paths, prompts and the
            merged ``Hyperparameters`` fields). The optional keys
            ``validation_steps``, ``validation_epochs``, ``with_prior_preservation``,
            ``snr_gamma``, ``do_edm_style_training``, ``train_text_encoder`` and
            ``report_to`` only add their flag(s) when truthy.

    Returns:
        The script text. The ``#!/bin/bash`` shebang is on the very first line, and
        the ``accelerate launch`` call is a single shell command: every argument line
        except the last ends with a line-continuation backslash.

    Uses the module-level ``diffusers_dir`` for the working directory.
    """
    header = f"""#!/bin/bash

cd {diffusers_dir}/examples/dreambooth

export MODEL_NAME="{variables['model_name']}"
export INSTANCE_DATA_DIR="{variables['instance_data_dir']}"
export OUTPUT_DIR="{variables['output_dir']}"
export CLASS_DATA_DIR="{variables['class_data_dir']}"
export CONFIG_FILE="{variables['acc_config_path']}"

"""

    # Collect one entry per line of the launch command; joining them with
    # " \\\n" below guarantees correct continuations no matter which optional
    # flags are present (no dangling or missing backslashes).
    args = [
        "accelerate launch --config_file=$CONFIG_FILE",
        "  train_dreambooth.py",
        "  --pretrained_model_name_or_path=$MODEL_NAME",
        "  --instance_data_dir=$INSTANCE_DATA_DIR",
        "  --output_dir=$OUTPUT_DIR",
        f'  --instance_prompt="{variables["instance_prompt"]}"',
        f"  --resolution={variables['resolution']}",
        f"  --train_batch_size={variables['train_batch_size']}",
        f"  --gradient_accumulation_steps={variables['gradient_accumulation_steps']}",
        f"  --learning_rate={variables['learning_rate']}",
        f'  --lr_scheduler="{variables["lr_scheduler"]}"',
        f"  --lr_warmup_steps={variables['lr_warmup_steps']}",
        f"  --max_train_steps={variables['max_train_steps']}",
        f"  --checkpointing_steps={variables['checkpointing_steps']}",
        f'  --mixed_precision="{variables["mixed_precision"]}"',
        f"  --prior_loss_weight={variables['prior_loss_weight']}",
        f"  --num_validation_images={variables['num_validation_images']}",
    ]

    validation_steps = variables.get("validation_steps")
    validation_epochs = variables.get("validation_epochs")
    if validation_steps:
        args.append(f"  --validation_steps={validation_steps}")
    if validation_epochs:
        args.append(f"  --validation_epochs={validation_epochs}")
    if validation_steps or validation_epochs:
        # Emit the prompt exactly once, even when both schedules are configured.
        args.append(f'  --validation_prompt="{variables["validation_prompt"]}"')

    if variables.get("with_prior_preservation"):
        args += [
            "  --with_prior_preservation",
            "  --class_data_dir=$CLASS_DATA_DIR",
            f'  --class_prompt="{variables["class_prompt"]}"',
        ]

    if variables.get("snr_gamma"):
        args.append(f"  --snr_gamma={variables['snr_gamma']}")

    if variables.get("do_edm_style_training"):
        args.append("  --do_edm_style_training")

    if variables.get("train_text_encoder"):
        args.append("  --train_text_encoder")

    if variables.get("report_to"):
        args.append(f'  --report_to="{variables["report_to"]}"')

    return header + " \\\n".join(args) + "\n"


def generate_command(
    placeholder_token, obj, instance, acc_config, report_to, hyperparameters
):
    """Resolve the training variables for one instance and render its launch script.

    Uses the notebook-level ``repo_dir`` as the repository root.

    Returns:
        Tuple ``(command, variables)``: the bash script text from ``parse_variables``
        and the variable dict from ``process_variables`` it was rendered from.
    """
    resolved = process_variables(
        obj=obj,
        instance=instance,
        repo_dir=repo_dir,
        hyperparameters=hyperparameters,
        placeholder_token=placeholder_token,
        acc_config=acc_config,
        report_to=report_to,
    )
    return parse_variables(resolved), resolved


@dataclass
class Hyperparameters:
    """Training hyperparameters for the generated DreamBooth script.

    ``process_variables`` merges these fields into the template variables via
    ``__dict__``, and ``parse_variables`` renders them as command-line flags.
    ``max_train_steps`` is intentionally absent: it is derived from the number of
    instance images. ``parse_variables`` also honours an optional
    ``validation_epochs`` key, which is not declared here.
    """

    resolution: int = 512  # --resolution
    train_batch_size: int = 1  # --train_batch_size
    gradient_accumulation_steps: int = 1  # --gradient_accumulation_steps
    learning_rate: float = 5e-6  # --learning_rate
    lr_scheduler: str = "constant"  # --lr_scheduler
    lr_warmup_steps: int = 0  # --lr_warmup_steps
    checkpointing_steps: int = 500  # --checkpointing_steps
    validation_steps: int = 500  # --validation_steps; falsy omits validation flags
    mixed_precision: str = "bf16"  # --mixed_precision
    with_prior_preservation: bool = True  # adds prior-preservation + class-data flags
    prior_loss_weight: float = 1.0  # --prior_loss_weight
    num_validation_images: int = 4  # --num_validation_images
    train_text_encoder: bool = False  # adds --train_text_encoder when True
    snr_gamma: float = 5.0  # --snr_gamma; falsy omits the flag

In [4]:
# Target object class and instance; the instance name doubles as the placeholder token.
obj = "articulated dump truck"
instance = "TA230"
placeholder_token = "<" + instance + ">"

acc_config = "default_config"  # accelerate config file name (device setup)
report_to = "wandb"
hyperparameters = Hyperparameters()

command, variables = generate_command(
    placeholder_token=placeholder_token,
    obj=obj,
    instance=instance,
    acc_config=acc_config,
    report_to=report_to,
    hyperparameters=hyperparameters,
)
print(command)


#!/bin/bash

cd /mnt/ssd2/xin/repo/diffusers/examples/dreambooth

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DATA_DIR="/mnt/ssd2/xin/repo/DART/diversification/instance_data/articulated_dump_truck/TA230"
export OUTPUT_DIR="/mnt/ssd2/xin/repo/DART/diversification/dreambooth/output/sd1-5/articulated_dump_truck/TA230_1"
export CLASS_DATA_DIR="/mnt/ssd2/xin/repo/DART/diversification/dreambooth/class_data/sd1-5/articulated_dump_truck"
export CONFIG_FILE="/home/chenxin/.cache/huggingface/accelerate/default_config.yaml"

accelerate launch --config_file=$CONFIG_FILE\
  train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DATA_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of a <TA230> articulated dump truck" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-06 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1200 \
  --check

In [5]:
# Save the script at the path reserved by `process_variables` (see `variables`), so the
# file name always matches the OUTPUT_DIR embedded in `command`. Re-deriving a fresh
# unique path here would drift on re-runs (e.g. to TA230_2.sh) while OUTPUT_DIR
# inside the script still points at TA230_1.
script_path = Path(variables["script_path"])
script_path.parent.mkdir(parents=True, exist_ok=True)
script_path.write_text(command, encoding="utf-8")
print(f"Script saved at \n{script_path}")

Script saved at 
/mnt/ssd2/xin/repo/DART/diversification/dreambooth/scripts/sd1-5/articulated_dump_truck/TA230_1.sh
