In [1]:
import os, json

In [2]:
# some definitions first

# Sentinel value marking an option that is a bare switch (rendered as "--name"
# with no "=value" part).
FLAG_ON = "__FLAG_ON"


def args_to_cmd(args):
    """Render an option dict as a single command-line string.

    Each entry becomes one space-separated token:

    * value equal to ``FLAG_ON``  ->  ``--key``
    * string value                ->  ``--key="value"``
    * anything else               ->  ``--key=value`` (formatted with ``str``)

    String values are wrapped in double quotes, and the characters that stay
    special inside a double-quoted POSIX shell word (backslash, double quote,
    dollar sign, backtick) are backslash-escaped, so a value such as
    ``say "hi"`` reaches the program intact instead of terminating the quoted
    word early.  Strings without those characters render exactly as before.

    Args:
        args: mapping of option name -> value; iteration order is preserved.

    Returns:
        The joined option string ("" for an empty mapping).
    """
    # Translation table for the four characters a POSIX shell still interprets
    # between double quotes.
    shell_escapes = str.maketrans({
        "\\": "\\\\",
        '"': '\\"',
        "$": "\\$",
        "`": "\\`",
    })

    ret = []
    for k, v in args.items():
        if v == FLAG_ON:
            ret.append(f"--{k}")
        else:
            if isinstance(v, str):
                v = '"' + v.translate(shell_escapes) + '"'
            ret.append(f"--{k}={v}")
    return " ".join(ret)

                
    
def create_args(datadir, resolution=512, batsize=2, gradacc=1,
                lr=1e-6, lr_scheduler="constant", trainsteps=100, save_steps=500, prior_preservation=True, tune_encoder=False,
                modelname="CompVis/stable-diffusion-v1-4", generate_every=-1, mode="full-ft", numvectors=1, suffix=None):
    """Assemble the option dict for one fine-tuning run.

    The returned dict maps option names to values and is meant to be rendered
    with ``args_to_cmd``; entries whose value is ``FLAG_ON`` are value-less
    switches.

    Args:
        datadir: dataset folder.  ``instance_images``, ``class_images`` and the
            output folder are all resolved inside it.
        resolution: stored under ``"resolution"``.
        batsize: stored under ``"train_batch_size"``.
        gradacc: stored under ``"gradient_accumulation_steps"``.
        lr: stored under ``"learning_rate"``.
        lr_scheduler: stored under ``"lr_scheduler"`` (cosine, constant, linear, ...).
        trainsteps: stored under ``"max_train_steps"``.
        save_steps: stored under ``"save_steps"``.
        prior_preservation: when truthy, adds the ``with_prior_preservation``
            switch and ``prior_loss_weight=1.0``.
        tune_encoder: only consulted when ``mode == "full-ft"``; adds the
            ``train_text_encoder`` switch.
        modelname: stored under ``"pretrained_model_name_or_path"``.
        generate_every: stored under ``"generate_every"``.
        mode: one of ``"full-ft"``, ``"kvemb-ft"``, ``"emb-ft"``,
            ``"emb-mem-ft"``, ``"textenc-ft"``; selects which ``train_*``
            switch is added.  Any other value adds no training switch.
        numvectors: stored under ``"num_vectors_per_extra_token"``.
        suffix: optional tag appended to the output folder name.

    Returns:
        dict of option name -> value.  ``"output_dir"`` is
        ``<datadir>/model-<mode>`` or, with a suffix,
        ``<datadir>/model-<mode>_<suffix>``.
    """
    # Build the folder name step by step.  Writing it as
    #   f"model-{mode}" + ("_" + suffix) if suffix is not None else ""
    # is wrong: the conditional expression binds looser than "+", so with
    # suffix=None the whole name collapses to "" and output_dir becomes datadir.
    output_name = f"model-{mode}"
    if suffix is not None:
        output_name += f"_{suffix}"

    args = {
        "pretrained_model_name_or_path": modelname,
        "instance_data_dir": os.path.join(datadir, "instance_images"),
        "class_data_dir": os.path.join(datadir, "class_images"),
        "output_dir": os.path.join(datadir, output_name),
        "resolution": resolution,
        "train_batch_size": batsize,
        "gradient_accumulation_steps": gradacc,
        "gradient_checkpointing": FLAG_ON,
        "use_8bit_adam": FLAG_ON,
        "learning_rate": lr,
        "lr_scheduler": lr_scheduler,   # cosine or constant or linear or ...
        "lr_warmup_steps": 50,          # fixed here, not exposed as a parameter
        "num_class_images": 200,        # fixed here, not exposed as a parameter
        "max_train_steps": trainsteps,
        "save_steps": save_steps,
        "generate_every": generate_every,
        "num_vectors_per_extra_token": numvectors,
    }
    if prior_preservation:
        args["with_prior_preservation"] = FLAG_ON
        args["prior_loss_weight"] = 1.0

    # Exactly one training switch per recognised mode; unrecognised modes are
    # left without one (callers may still add options to the dict afterwards).
    if mode == "full-ft":
        args["train_unet"] = FLAG_ON
        if tune_encoder:
            args["train_text_encoder"] = FLAG_ON
    elif mode == "kvemb-ft":
        args["train_kv_emb_only"] = FLAG_ON
    elif mode == "emb-ft":
        args["train_emb_only"] = FLAG_ON
    elif mode == "emb-mem-ft":
        args["train_emb_mem"] = FLAG_ON
    elif mode == "textenc-ft":
        args["train_textenc_only"] = FLAG_ON
    return args

In [3]:
# ---- Shared run configuration (later mode-specific cells may override some of these) ----

# Base checkpoint to fine-tune.  Alternatives that have been used here:
#   "realisticVisionV13", "dreamlike-art/dreamlike-diffusion-1.0",
#   "Linaqruf/anything-v3.0", "Lykon/DreamShaper", "CompVis/stable-diffusion-v1-4"
modelname = "runwayml/stable-diffusion-v1-5"

# Root folder that holds one sub-folder per dataset.
datasetdir = "datasets"

# Dataset folder names to train on (an empty list disables the include filter).
# Other names used before: "teddybear", "alvin_dog", "cat_statue", "elephant",
# "cat", "dog", "cheburashka".
include = ["sophia_winsell"]
# Dataset folder names that are always skipped.
exclude = []

# Training defaults.
batsize = 2
numvectors = 1
lr_scheduler = "constant"
tune_encoder = True
use_prior_preservation = True

In [4]:
# Settings for full finetuning.
# The commented-out triple-quote markers around the assignments are a toggle:
# removing the leading "# " from the first marker turns the whole cell body
# into a string literal and thereby disables it.
# """
mode = "full-ft"
trainsteps = 500
# 200 total steps for objects
generate_every = 50
lr = 1e-6   # 5e-5 for kvemb, 1e-6 for full, 1e-4 and 1000 steps for emb
use_prior_preservation = True
tune_encoder = False

# """

In [5]:
# # For KV and emb finetuning
# generate_every = 100
# trainsteps = 1000
# mode = "kvemb-ft"
# lr = 1e-4  # 5e-5 for kvemb, 1e-6 for full, 1e-4 and 1000 steps for emb

In [6]:
# For emb-only finetuning
# NOTE: these settings are currently DISABLED. The assignments sit inside a
# triple-quoted string literal, so running this cell assigns nothing and only
# echoes the string as the cell output. To enable them, comment out the opening
# triple-quote line (prefix it with "# "), as the full-finetuning cell does.
"""
generate_every = 100
trainsteps = 1000
mode = "emb-ft"
lr = 5e-5   # 5e-5 for kvemb, 1e-6 for full, 1e-4 and 1000 steps for emb
lr_scheduler = "constant"
use_prior_preservation = False
batsize = 4
numvectors = 4
# """

'\ngenerate_every = 100\ntrainsteps = 1000\nmode = "emb-ft"\nlr = 5e-5   # 5e-5 for kvemb, 1e-6 for full, 1e-4 and 1000 steps for emb\nlr_scheduler = "constant"\nuse_prior_preservation = False\nbatsize = 4\nnumvectors = 4\n# '

In [7]:
# For emb-mem finetuning
# NOTE: these settings are currently DISABLED. The assignments sit inside a
# triple-quoted string literal, so running this cell assigns nothing and only
# echoes the string as the cell output. To enable them, comment out the opening
# triple-quote line (prefix it with "# "), as the full-finetuning cell does.
"""
generate_every = 100
trainsteps = 1000
mode = "emb-mem-ft"
lr = 5e-5   # 5e-5 for kvemb, 1e-6 for full, 1e-4 and 1000 steps for emb
lr_scheduler = "constant"
use_prior_preservation = False
batsize = 4
numvectors = 8
# """

'\ngenerate_every = 100\ntrainsteps = 1000\nmode = "emb-mem-ft"\nlr = 5e-5   # 5e-5 for kvemb, 1e-6 for full, 1e-4 and 1000 steps for emb\nlr_scheduler = "constant"\nuse_prior_preservation = False\nbatsize = 4\nnumvectors = 8\n# '

In [8]:
# # For text-encoder-only finetuning
# generate_every = 40
# trainsteps = 400
# mode = "textenc-ft"
# lr = 5e-5   # 5e-5 for kvemb, 1e-6 for full, 1e-4 and 1000 steps for emb

In [9]:
# for alvin dog: 50+100 updates

In [10]:
if mode == "emb-ft":
    numvectorses = [1, 2, 4, 16]
    numvectorses = [4]
elif mode == "emb-mem-ft":
    numvectorses = [4, 8, 32, 100, 1000]
elif mode == "full-ft":
    numvectorses = [1]
        
for numvectors in numvectorses:
    for x in os.listdir("datasets"):
        if len(include) > 0 and x not in include:
            continue
        if x in exclude:
            continue
        print("")
        print(x)
        print(("==="*30 + "\n") * 5)
        configfile = "finetune.config"
        if mode == "full-ft" or mode == "kvemb-ft": 
            configfile = "dreambooth.config"
        with open(os.path.join("datasets", x, configfile)) as f:
            settings = json.load(f)
            # required settings
            datadir = os.path.join("datasets", x)
            settings["instance_prompts"] = ";".join(settings["instance_prompts"])
            settings["class_prompts"] = ";".join(settings["class_prompts"])

            # override other default settings
            args = create_args(datadir, generate_every=generate_every, batsize=batsize,
                               trainsteps=trainsteps, mode=mode, lr=lr, lr_scheduler=lr_scheduler, 
                               prior_preservation=use_prior_preservation, tune_encoder=tune_encoder,
                               modelname=modelname, numvectors=numvectors, suffix=f"{numvectors}vectors")

            for k, v in settings.items():
                args[k] = v

            cmdargs = args_to_cmd(args)
            print(cmdargs)

            print("training")
            if mode == "full-ft":
                !accelerate launch train_dreambooth1.py {cmdargs}
            else:
                !accelerate launch train_dreambooth.py {cmdargs}
                
            print("done training")
    


sophia_winsell

--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" --instance_data_dir="datasets/sophia_winsell/instance_images" --class_data_dir="datasets/sophia_winsell/class_images" --output_dir="datasets/sophia_winsell/model-full-ft_1vectors" --resolution=512 --train_batch_size=2 --gradient_accumulation_steps=1 --gradient_checkpointing --use_8bit_adam --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=50 --num_class_images=200 --max_train_steps=500 --save_steps=500 --generate_every=50 --num_vectors_per_extra_token=1 --with_prior_preservation --prior_loss_weight=1.0 --train_unet --instance_prompts="a photo of <sophia_winsell> woman" --class_prompts="a photo of a woman" --initialize_extra_tokens="<sophia_winsell>=custom" --instance_type="portrait" --generate_concept="<sophia_winsell> woman" --generate_concept_class="woman"
training
Fetching 19 files: 100%|█████████████████████| 19/19 [00:00<00:00, 17650.45it/s]
Fetching 19 files: 100%|███████████████████

100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.88it/s]
running prompt: a bust statue of <sophia_winsell> woman head, 4k 8k 5k, olympus, canon r3, fujifilm xt3
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.88it/s]
done generating
35 5 7
grid saved
Steps:  20%|██▌          | 100/500 [04:07<10:59,  1.65s/it, loss=0.186, lr=1e-6]Saving the entire model

Fetching 19 files: 100%|█████████████████████| 19/19 [00:00<00:00, 17944.56it/s][A
generating
generation script called with args: {'outputdir': 'datasets/sophia_winsell/model-full-ft_1vectors', 'concept': '<sophia_winsell> woman', 'conceptclass': 'woman', 'step': 100, 'gpu': 0, 'instancetype': 'portrait'}
Some weights of the model checkpoint at datasets/sophia_winsell/model-full-ft_1vectors/text_encoder were not used when initializing CLIPTextModel: ['text_model.embeddings.token_embedding.gradmask']
- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model t

100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.89it/s]
running prompt: a portrait of <sophia_winsell> woman in anime style, manga style, studio ghibli, 2D, masterpiece, detailed
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.89it/s]
running prompt: a dramatic digital portrait painting of <sophia_winsell> woman, artistic, greg rutkowski, dramatic harsh light, 4k, trending on artstation
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.88it/s]
running prompt: an oil portrait painting of <sophia_winsell> woman in the style of vincent van gogh
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.88it/s]
running prompt: a bust statue of <sophia_winsell> woman head, 4k 8k 5k, olympus, canon r3, fujifilm xt3
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.88it/s]
done generating
35 5 7
grid saved
Steps:  50%|██████      | 250/500 [12:23<06:51,  1.65s/it, loss=0.0593, lr=1e-6

100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.79it/s]
running prompt: a photo of <sophia_winsell> woman
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.89it/s]
running prompt: a portrait photo of <sophia_winsell> woman smiling, 4k, highly detailed, realistic, olympus, fujifilm
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.88it/s]
running prompt: a portrait of <sophia_winsell> woman in anime style, manga style, studio ghibli, 2D, masterpiece, detailed
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.88it/s]
running prompt: a dramatic digital portrait painting of <sophia_winsell> woman, artistic, greg rutkowski, dramatic harsh light, 4k, trending on artstation
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.87it/s]
running prompt: an oil portrait painting of <sophia_winsell> woman in the style of vincent van gogh
100%|███████████████████████████████████████████|

prompts: ['a photo of woman', 'a photo of <sophia_winsell> woman', 'a portrait photo of <sophia_winsell> woman smiling, 4k, highly detailed, realistic, olympus, fujifilm', 'a portrait of <sophia_winsell> woman in anime style, manga style, studio ghibli, 2D, masterpiece, detailed', 'a dramatic digital portrait painting of <sophia_winsell> woman, artistic, greg rutkowski, dramatic harsh light, 4k, trending on artstation', 'an oil portrait painting of <sophia_winsell> woman in the style of vincent van gogh', 'a bust statue of <sophia_winsell> woman head, 4k 8k 5k, olympus, canon r3, fujifilm xt3']
running prompt: a photo of woman
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.81it/s]
running prompt: a photo of <sophia_winsell> woman
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  3.90it/s]
running prompt: a portrait photo of <sophia_winsell> woman smiling, 4k, highly detailed, realistic, olympus, fujifilm
100%|██████████████████████████████