In [None]:
import gc
import glob
import json
import math
import shlex
from pathlib import Path

from IPython.display import HTML

model_path  = ''
def train_model(dataset,
                instance_prompt="sks young man",
                class_prompt="a young man",
                base="ItsJayQz/SynthwavePunk-v2",
                resolution=512,
                max_train_steps=None):
    
    global pipe
    pipe = None
    gc.collect()

    if class_prompt:
        prior_dir = f"{class_prompt.replace(' ','-')}-prior"
        #class_imgs = len(get_image_files(prior_dir))
        theme_args = f'''--class_data_dir="{prior_dir}" \
          --with_prior_preservation \
          --class_prompt="{class_prompt}" \
          --num_class_images="200"
        '''
    else:
        theme_args = []
        
    instance_imgs = [f for f in Path(dataset).iterdir() if f.name != "labels.txt"]
    
    global model_path
    model_dir = f'models/{dataset}-{prior_dir}-prior-labeled-sks-cosinelr'
    model_path = model_dir
    
    if max_train_steps is None:
        max_train_steps = int((math.log10(len(instance_imgs)) * 2 + 1) * 400)
    
    !accelerate launch train_dreambooth.py \
      --pretrained_model_name_or_path={base}  \
      --instance_data_dir={dataset} \
      --output_dir={model_dir} \
      --with_prior_preservation --prior_loss_weight=1.0 \
      --save_interval={max_train_steps//5} \
      --instance_prompt="{instance_prompt}" \
      --class_prompt="{class_prompt}" \
      --resolution={resolution} \
      --train_batch_size=1 \
      --gradient_accumulation_steps=2 --gradient_checkpointing \
      --use_8bit_adam \
      --learning_rate=2e-6 \
      --lr_scheduler="constant" \
      --lr_warmup_steps=0 \
      --max_train_steps={max_train_steps} \
      {theme_args}
    
    with open(f'{model_dir}/my_metadata.json', 'w') as f:
        json.dump(dict(
                model = model_dir,
                dataset=dataset,
                instance_prompt=instance_prompt,
                class_prompt=class_prompt,
                base=base,
                max_train_steps=max_train_steps,
                resolution=resolution,
        ), f)

In [None]:
from diffusers import DiffusionPipeline
import torch
import os
import random
from pathlib import Path

def load_model(model_id):
    """Load `model_id` as a float16 pipeline on CUDA.

    Uses the custom pipeline from ./lpw_stable_diffusion.py. The source id is
    attached to the pipeline as `.modeldir` so callers can recover it later.
    """
    loaded = DiffusionPipeline.from_pretrained(
        model_id,
        custom_pipeline="./lpw_stable_diffusion.py",
        torch_dtype=torch.float16,
    )
    loaded = loaded.to("cuda")
    loaded.modeldir = model_id
    return loaded

import os,base64

def generate(model_id, _dirname, prompt, negative_prompt=None, seed=31337, steps=50, N=9, w=512, h=512, guidance_scale=9):
    """Render N images, save them with their settings, and return an HTML grid.

    Args:
        model_id: a model path/name (str or path-like), loaded with
            `load_model` for this call only, or a pipeline previously
            returned by `load_model`.
        _dirname: ignored. Kept so existing call sites keep working. Output
            always goes to `<basename of model>/<random token>/`.
        prompt: positive prompt.
        negative_prompt: optional negative prompt.
        seed: image i is generated with seed + i*512.
        steps: number of inference steps.
        N: number of images.
        w: image width, passed straight to the pipeline.
        h: image height, passed straight to the pipeline.
        guidance_scale: passed straight to the pipeline.

    Returns:
        IPython HTML showing the saved JPEGs side by side.
    """
    generators = [torch.Generator(device="cuda").manual_seed(seed + i * 512) for i in range(N)]

    if isinstance(model_id, (str, os.PathLike)):
        model_id = os.fspath(model_id)
        pipe = load_model(model_id)
    else:
        pipe = model_id
        model_id = pipe.modeldir

    images = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=steps,
                  guidance_scale=guidance_scale, generator=generators,
                  num_images_per_prompt=N, width=w, height=h).images
    # Drop our reference right away. A pipeline loaded above can then be
    # released before the files are written; a caller-supplied one stays alive.
    del pipe

    token = base64.urlsafe_b64encode(os.urandom(16)).decode("ascii")
    dirname = f'{Path(model_id).name}/{token}'
    # Random, so effectively always new. Created in Python rather than via an
    # unquoted, unchecked `rm -rf ... && mkdir -p ...` shell command.
    os.makedirs(dirname, exist_ok=True)

    with open(f'{dirname}/meta.json', 'w') as f:
        json.dump(dict(
            model_id=model_id,
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            steps=steps,
            w=w,
            h=h,
            guidance_scale=guidance_scale,
        ), f)

    for i, img in enumerate(images):
        img.save(f'{dirname}/{i}.jpg')

    # The random query string keeps the browser from showing a cached image.
    return HTML(''.join(
        f'<img style="float:left; width: 32%; margin:5px;" src="{dirname}/{i}.jpg?{random.randint(0,2**31)}" />'
        for i in range(len(images))
    ))

In [None]:
# Fine-tune on ./train_imgs using train_model's default prompts and base model.
train_model(dataset="train_imgs")

In [None]:
# Show the output directory recorded by the last train_model() call.
model_path

In [None]:
# Compare every saved checkpoint on a plain photo prompt.
# Sorting by (length, name) gives numeric order for the shared
# ".../checkpoint-<step>" prefix; plain sorting would put checkpoint-1000 before checkpoint-200.
for chkpt in sorted(glob.glob(model_path + '/checkpoint-*'), key=lambda p: (len(p), p)):
    display(generate(chkpt, Path(chkpt).name, N=3,
      prompt="beautiful portrait photo of sks young man",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Checkpoint:", chkpt)

In [None]:
# Compare every saved checkpoint on an oil-painting prompt.
# Sorting by (length, name) gives numeric order for the shared
# ".../checkpoint-<step>" prefix; plain sorting would put checkpoint-1000 before checkpoint-200.
for chkpt in sorted(glob.glob(model_path + '/checkpoint-*'), key=lambda p: (len(p), p)):
    display(generate(chkpt, Path(chkpt).name, N=3,
      prompt="sks young man, beautiful oil on canvas portrait",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Checkpoint:", chkpt)

In [None]:
# Compare every saved checkpoint on the synthwave/inkpunk prompt, with style tokens first.
# Sorting by (length, name) gives numeric order for the shared
# ".../checkpoint-<step>" prefix; plain sorting would put checkpoint-1000 before checkpoint-200.
for chkpt in sorted(glob.glob(model_path + '/checkpoint-*'), key=lambda p: (len(p), p)):
    display(generate(chkpt, Path(chkpt).name, N=3,
      prompt="(snthwve style:1) (nvinkpunk:0.7) sks young man, (hallucinating colorful soap bubbles), by jeremy mann, by sandra chevrier, by dave mckean and richard avedon and maciej kuciara, punk rock, tank man, high detailed, 8k, sharp focus, natural lighting",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Checkpoint:", chkpt)

In [None]:
# Same synthwave/inkpunk prompt, but with the subject token placed first.
# Sorting by (length, name) gives numeric order for the shared
# ".../checkpoint-<step>" prefix; plain sorting would put checkpoint-1000 before checkpoint-200.
for chkpt in sorted(glob.glob(model_path + '/checkpoint-*'), key=lambda p: (len(p), p)):
    display(generate(chkpt, Path(chkpt).name, N=3,
      prompt="sks young man (snthwve style:1) (nvinkpunk:0.7), (hallucinating colorful soap bubbles), by jeremy mann, by sandra chevrier, by dave mckean and richard avedon and maciej kuciara, punk rock, tank man, high detailed, 8k, sharp focus, natural lighting",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Checkpoint:", chkpt)

In [None]:
# Checkpoint used for the weight sweep below. It gets its own name instead of
# overwriting `model_path` (the training output dir). Otherwise, re-running the
# checkpoint comparison cells would glob inside this checkpoint and find nothing.
checkpoint_path = 'models/train_imgs-a-young-man-prior-prior-labeled-sks-cosinelr' + '/checkpoint-756'
model = load_model(checkpoint_path)

In [None]:
# Sweep the emphasis weight on the subject token.
# .tolist() yields plain floats. Interpolating the 0-d tensors directly would
# write "(sks young man:tensor(0.8000))", which is not a numeric
# "(text:weight)" and so would not act as a weight.
for w in torch.linspace(0.8, 1.5, 6).tolist():
    display(generate(model, f'weight-{w:.2f}', N=3, steps=100,
      prompt=f"(sks young man:{w:.2f}) (snthwve style:1) (nvinkpunk:0.7), (hallucinating colorful soap bubbles), by jeremy mann, by sandra chevrier, by dave mckean and richard avedon and maciej kuciara, punk rock, tank boy, high detailed, 8k, sharp focus, natural lighting",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print(f"Weight: {w:.2f}")