# Stable Diffusion Finetuning
Want to generate images that look somewhat like you? You're in the right place. Fine-tuning takes about 15 minutes, after which we can generate images in many different styles based on the data used for fine-tuning.

- Create a folder ```train_imgs``` and add 5-10 images; could be portraits, close-ups or full body images.
- Don't forget to change the ```instance_prompt``` and ```class_prompt``` arguments of the ```train_model``` function in the following cell so the model knows what it should learn from the images we provide.

In [None]:
import json, gc
from IPython.display import HTML
import glob
import math
from pathlib import Path

# Output directory of the most recent training run; train_model() overwrites it.
model_path = ""
def train_model(dataset,
                instance_prompt="sks young woman",
                class_prompt="a young woman",
                base="ItsJayQz/SynthwavePunk-v2",
                resolution=512,
                max_train_steps=None):
    
    global pipe
    pipe = None
    gc.collect()

    if class_prompt:
        prior_dir = f"{class_prompt.replace(' ','-')}-prior"
        #class_imgs = len(get_image_files(prior_dir))
        theme_args = f'''--class_data_dir="{prior_dir}" \
          --with_prior_preservation \
          --class_prompt="{class_prompt}" \
          --num_class_images="200"
        '''
    else:
        theme_args = []
        
    instance_imgs = [f for f in Path(dataset).iterdir() if f.name != "labels.txt"]
    
    global model_path
    model_dir = f'models/{dataset}-{prior_dir}-prior-labeled-sks-cosinelr'
    model_path = model_dir
    
    if max_train_steps is None:
        max_train_steps = int((math.log10(len(instance_imgs)) * 2 + 1) * 400)
    
    !accelerate launch train_dreambooth.py \
      --pretrained_model_name_or_path={base}  \
      --instance_data_dir={dataset} \
      --output_dir={model_dir} \
      --with_prior_preservation --prior_loss_weight=1.0 \
      --save_interval={max_train_steps//5} \
      --instance_prompt="{instance_prompt}" \
      --class_prompt="{class_prompt}" \
      --resolution={resolution} \
      --train_batch_size=1 \
      --gradient_accumulation_steps=2 --gradient_checkpointing \
      --use_8bit_adam \
      --learning_rate=2e-6 \
      --lr_scheduler="constant" \
      --lr_warmup_steps=0 \
      --max_train_steps={max_train_steps} \
      {theme_args}
    
    with open(f'{model_dir}/my_metadata.json', 'w') as f:
        json.dump(dict(
                model = model_dir,
                dataset=dataset,
                instance_prompt=instance_prompt,
                class_prompt=class_prompt,
                base=base,
                max_train_steps=max_train_steps,
                resolution=resolution,
        ), f)

In [None]:
from diffusers import DiffusionPipeline
import torch
import os
import random
from pathlib import Path

def load_model(model_id):
    """
    Build a float16 DiffusionPipeline for `model_id` and move it to CUDA.

    The pipeline is created with the custom pipeline file
    ./lpw_stable_diffusion.py. The id it was loaded from is stored on the
    returned object as `modeldir`.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        custom_pipeline="./lpw_stable_diffusion.py",
        torch_dtype=torch.float16,
    )
    pipeline = pipeline.to("cuda")
    pipeline.modeldir = model_id
    return pipeline

import os,base64

def generate(model_id, _dirname, prompt, negative_prompt=None, seed=31337, steps=50, N=9, w=512, h=512, guidance_scale=9):
    """
    Generate N images, save them to disk and return an HTML gallery of them.

    Args:
        model_id: Either a model path/id string (the pipeline is loaded for
            this call) or an already loaded pipeline carrying a `modeldir`
            attribute, as returned by load_model().
        _dirname: Unused; kept so existing calls keep working. The output
            directory is always `<model name>/<random token>`.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Optional negative prompt passed to the pipeline.
        seed: Base seed; image i uses `seed + i*512`.
        steps: Number of inference steps.
        N: Number of images to generate.
        w, h: Width and height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        An IPython HTML object with one <img> tag per generated image.
    """
    # One generator per image so each image has its own reproducible seed.
    generators = [torch.Generator(device="cuda").manual_seed(seed + i*512) for i in range(N)]
    if isinstance(model_id, str):
        # Reuse the shared loader instead of repeating its arguments here.
        pipe = load_model(model_id)
    else:
        pipe = model_id
        model_id = pipe.modeldir

    images = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=steps, guidance_scale=guidance_scale, generator=generators, num_images_per_prompt=N, width=w, height=h).images

    # A fresh random sub-directory per call, so earlier results are never
    # overwritten. Created from Python rather than through the shell so that
    # model names containing spaces or shell metacharacters are handled.
    dirname = f'{Path(model_id).name}/{base64.urlsafe_b64encode(os.urandom(16)).decode("ascii")}'
    Path(dirname).mkdir(parents=True, exist_ok=True)

    # Store the generation settings alongside the images.
    with open(f'{dirname}/meta.json', 'w') as f:
        json.dump(dict(
            model_id = model_id,
            prompt = prompt,
            negative_prompt = negative_prompt,
            seed = seed,
            steps = steps,
            w = w,
            h = h,
            guidance_scale = guidance_scale,
        ), f)
    for i, img in enumerate(images):
        img.save(f'{dirname}/{i}.jpg')

    # Only drops this function's reference; a pipeline passed in by the
    # caller stays alive.
    del pipe

    # The random query string makes every <img> URL unique.
    return HTML(''.join([f'<img style="float:left; width: 32%; margin:5px;" src="{dirname}/{i}.jpg?{random.randint(0,2**31)}" />' for i in range(N)]))

## Training
- Fine-tuning on your custom images should take around 15 minutes
- Go and grab a coffee; you deserve it!

In [None]:

# Fine-tune on the images in ./train_imgs with the default prompts and base model.
train_model(dataset="train_imgs")

## Best checkpoint?
- We periodically save checkpoints while training.
- Next, we try to figure out which checkpoint works best and generates images that look most like you.

In [None]:
# Compare every saved checkpoint on a plain portrait-photo prompt.
# Sorting by (length, name) lists checkpoints in numeric order of their
# suffix; a plain string sort would put checkpoint-1200 before checkpoint-240.
for chkpt in sorted(glob.glob(model_path+'/checkpoint-*'), key=lambda p: (len(p), p)):
    display(generate(chkpt, Path(chkpt).name, N=3,
      prompt="beautiful portrait photo of sks young woman",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Checkpoint:", chkpt)

In [None]:
# Compare every saved checkpoint on an oil-painting prompt.
# Sorting by (length, name) lists checkpoints in numeric order of their
# suffix; a plain string sort would put checkpoint-1200 before checkpoint-240.
for chkpt in sorted(glob.glob(model_path+'/checkpoint-*'), key=lambda p: (len(p), p)):
    display(generate(chkpt, Path(chkpt).name, N=3,
      prompt="sks young woman, beautiful oil on canvas portrait",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Checkpoint:", chkpt)

In [None]:
# Compare every saved checkpoint on a styled prompt that puts the style
# tokens before the subject.
# Sorting by (length, name) lists checkpoints in numeric order of their
# suffix; a plain string sort would put checkpoint-1200 before checkpoint-240.
for chkpt in sorted(glob.glob(model_path+'/checkpoint-*'), key=lambda p: (len(p), p)):
    display(generate(chkpt, Path(chkpt).name, N=3,
          prompt="(snthwve style:1) (nvinkpunk:0.7) sks young woman, (hallucinating colorful soap bubbles), by jeremy mann, by sandra chevrier, by dave mckean and richard avedon and maciej kuciara, punk rock, tank woman, high detailed, 8k, sharp focus, natural lighting",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Checkpoint:", chkpt)

In [None]:
# Compare every saved checkpoint on a styled prompt that puts the subject
# before the style tokens.
# Sorting by (length, name) lists checkpoints in numeric order of their
# suffix; a plain string sort would put checkpoint-1200 before checkpoint-240.
for chkpt in sorted(glob.glob(model_path+'/checkpoint-*'), key=lambda p: (len(p), p)):
    display(generate(chkpt, Path(chkpt).name, N=3,
      prompt="sks young woman (snthwve style:1) (nvinkpunk:0.7), (hallucinating colorful soap bubbles), by jeremy mann, by sandra chevrier, by dave mckean and richard avedon and maciej kuciara, punk rock, tank girl, high detailed, 8k, sharp focus, natural lighting",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Checkpoint:", chkpt)

The cells above generate images with every saved checkpoint, which should help you figure out the best one. Add its checkpoint number to the cell below to use it in all subsequent image generations.


In [None]:
# Complete the folder name of the checkpoint that looked best in the
# comparison cells.
best_checkpoint = "checkpoint-"

In [None]:
# Point model_path at the chosen checkpoint and load it.
if best_checkpoint == 'checkpoint-':
    raise ValueError("Set best_checkpoint to the name of one of the saved checkpoint folders first.")

# Strip a checkpoint component appended by an earlier run of this cell, so
# re-running it (or choosing a different checkpoint) does not nest paths like
# .../checkpoint-X/checkpoint-Y.
_train_dir = Path(model_path)
if _train_dir.name.startswith('checkpoint-'):
    _train_dir = _train_dir.parent
model_path = str(_train_dir / best_checkpoint)
model = load_model(model_path)

## Weight factor
The last thing to figure out is how much weight to assign to the prompt that we used for training. The next cell generates images with different weights so you can see what works best.

In [None]:
# Try six evenly spaced weights for the trained subject token.
# Convert the tensor to plain floats rounded to two decimals so the prompt
# contains e.g. 0.94 rather than float32 noise, and the label prints as a
# number instead of tensor(...).
for w in torch.linspace(0.8, 1.5, 6).tolist():
    w = round(w, 2)
    display(generate(model, f'weight-{w}', N=3, steps=100,
      prompt=f"(sks young woman:{w}) (snthwve style:1) (nvinkpunk:0.7), (hallucinating colorful soap bubbles), by jeremy mann, by sandra chevrier, by dave mckean and richard avedon and maciej kuciara, punk rock, tank girl, high detailed, 8k, sharp focus, natural lighting",
      negative_prompt="cartoon, 3d, (illustration:1.2), ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry"))
    print("Weight:", w)

## Start Painting
Feel free to play with the parameters that affect the generated images:
- Add/remove words from the prompt, depending on how you want to improvise.
- Change the number of steps. Generally, steps=50 works well enough.
- Changing the seed also changes the results, so feel free to change it to any number you like. It could be your lucky number as well.

In [None]:
# Begin with a short, unweighted prompt; every other setting keeps its default.
generate(
    model_id=model,
    _dirname="self",
    prompt="sks young woman snthwve style nvinkpunk, high detailed, 8k",
)

In [None]:
# Refine the prompt: apply the subject weight that looked best in the weight
# comparison cell and raise the number of inference steps.
generate(
    model_id=model,
    _dirname="self",
    prompt="(sks young woman:0.94) snthwve style nvinkpunk, high detailed, 8k",
    steps=100,
)

In [None]:
# Make the prompt more detailed, and pass a negative prompt to suppress
# artifacts.
generate(
    model_id=model,
    _dirname="self",
    negative_prompt="cartoon, 3d, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry",
    prompt="(sks young woman:0.94) as a beautiful god, snthwve style nvinkpunk (symmetry:1.1) (portrait of floral:1.05), (assassins creed style:0.8), pink and gold and opal color scheme, beautiful intricate filegrid facepaint, intricate, elegant, highly detailed, digital painting, artstation, concept art, smooth, sharp focus, illustration, art by greg rutkowski and alphonse mucha, 8k",
)

## Some examples

In [None]:
# Example: floral "beautiful god" portrait with subject weight 1.08, 70 steps.
generate(
    model_id=model,
    _dirname="self",
    negative_prompt="cartoon, 3d, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry",
    prompt="(sks young woman:1.08) as a beautiful god, snthwve style nvinkpunk (symmetry:1.1) (portrait of floral:1.05), beautiful intricate filegrid facepaint, intricate, elegant, highly detailed, digital painting, artstation, concept art, smooth, sharp focus, illustration, art by greg rutkowski and alphonse mucha, 8k",
    steps=70,
)

In [None]:
# Example: half-body portrait with paint splashes.
# Pass the already loaded `model` like the other example cells do; passing the
# path string made generate() load the whole pipeline from disk again.
generate(model, "self", steps=70,
         prompt="(sks young woman:1.08), style of joemadureira (nvinkpunk:0.7) snthwve style award winning sexy half body portrait in a jacket and cargo pants with ombre navy blue teal hairstyle with head in motion and hair flying, paint splashes, splatter, outrun, vaporware, shaded flat illustration, digital art, trending on artstation, highly detailed, fine detail, intricate",
         negative_prompt="cartoon, ((closeup)), ((disfigured)), ((deformed)), ((poorly drawn)), ((extra limbs)), blurry")

In [None]:
# Example: punk close-up portrait.
# The subject token must match the instance prompt used for training
# ("sks young woman"); this prompt previously used a different name,
# presumably left over from an earlier run.
generate(model, "self",
         prompt="snthwve style nvinkpunk close up portrait of a punk sks young woman, punk rock, roller derby, tank girl, bubblegum, edgy, dangerous, high detailed, 8k, by jeremy mann, by sandra chevrier, by dave mckean and richard avedon and maciej kuciara",
         negative_prompt="cartoon, ((closeup)), ((disfigured)), ((deformed)), ((poorly drawn)), ((extra limbs)), blurry")

In [None]:
# Example: half-body "beautiful god" portrait, 70 steps.
generate(
    model_id=model,
    _dirname="self",
    negative_prompt="cartoon, 3d, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry",
    prompt="half body portrait of sks young woman as a beautiful god, snthwve style nvinkpunk (symmetry:1.1), (assassins creed style:0.8), beautiful intricate filegrid facepaint, intricate, elegant, highly detailed, digital painting, artstation, concept art, smooth, sharp focus, illustration, art by greg rutkowski and alphonse mucha, 8k",
    steps=70,
)

### Mixing styles from the above paintings

In [None]:
# Mix: symmetric half-body "beautiful god" portrait.
# The subject token must match the instance prompt used for training
# ("sks young woman"); this prompt previously used a different name,
# presumably left over from an earlier run.
generate(model, "self",
         prompt="snthwve style nvinkpunk (symmetry:1.1) (half body portrait:1.1) of sks young woman as a beautiful god, (assassins creed style:0.8), beautiful intricate filegrid facepaint, intricate, elegant, highly detailed, digital painting, artstation, concept art, smooth, sharp focus, illustration, art by greg rutkowski and alphonse mucha, 8k",
         negative_prompt="cartoon, 3d, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry")

In [None]:
# Mix: half-body portrait with paint splashes, subject weight 1.1, 80 steps.
generate(
    model_id=model,
    _dirname="self",
    negative_prompt="cartoon, 3d, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry",
    prompt="(nvinkpunk:0.7) snthwve style award winning sexy half body portrait of (sks young woman:1.1) in a jacket and cargo pants with ombre navy blue teal hairstyle with head in motion and hair flying, paint splashes, splatter, outrun, vaporware, shaded flat illustration, digital art, trending on artstation, highly detailed, fine detail, intricate",
    steps=80,
)

In [None]:
# Mix: soap-bubble punk portrait.
# The subject token must match the instance prompt used for training
# ("sks young woman"); this prompt previously used a different name,
# presumably left over from an earlier run.
generate(model, "self",
         prompt="snthwve style nvinkpunk drunken beautiful sks young woman, (hallucinating colorful soap bubbles), by jeremy mann, by sandra chevrier, by dave mckean and richard avedon and maciej kuciara, punk rock, tank girl, high detailed, 8k",
         negative_prompt="cartoon, 3d, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry")

In [None]:
# Mix: floral "beautiful god" portrait with 60 steps and guidance scale 7.
generate(
    model_id=model,
    _dirname="self",
    negative_prompt="cartoon, 3d, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry",
    prompt="snthwve style nvinkpunk (symmetry:1.1) (portrait of floral:1.05) a sks young woman as a beautiful god, (assassins creed style:0.8), pink and gold and opal color scheme, beautiful intricate filegrid facepaint, intricate, elegant, highly detailed, digital painting, artstation, concept art, smooth, sharp focus, illustration, art by greg rutkowski and alphonse mucha, 8k",
    guidance_scale=7,
    steps=60,
)

## Already have a prompt in mind?
Paint your own

In [None]:
# Put your own prompt between the quotes below.
# (A comma was missing after the prompt argument, which made this cell a
# SyntaxError.)
generate(model, "self",
         prompt="",
         negative_prompt="cartoon, 3d, ((disfigured)), ((bad art)), ((deformed)), ((poorly drawn)), ((extra limbs)), ((close up)), ((b&w)), weird colors, blurry")