# Dreambooth Stable Diffusion - Winter 2022 Edition 
This Colab is based on Shivam Shrirao's repository and has been modified to use dependencies from late 2022, but with diffusers from revision `fbdf0a17055ffa34679cb34d986fabc1296d0785` (2023-03-02).

If you prefer a layout similar to the original Colab, you can use [dbsd_dec_2022_original.ipynb](https://colab.research.google.com/github/yushan777/dbsd-dec-2022/blob/main/dbsd_dec_2022_original.ipynb)

___
https://github.com/yushan777/dbsd-dec-2022

https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth


Join the Dreambooth Discord!
https://discord.gg/wNNs2JNF7G 
___
#### Instructions:
Run each cell in order and read the instructions or notes within each cell. The last few cells are not essential (they are mostly tools used during testing), but they have been left in since they may be useful.

In [None]:
#@title 0. Check type of GPU and VRAM available.
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader
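The query above prints one CSV line per GPU. If you want to check free VRAM from Python (for example against the VRAM column of the training table further down), a minimal parsing sketch could look like this. `parse_gpu_csv` is a hypothetical helper, and the sample line is illustrative, not captured from a real session.

```python
def parse_gpu_csv(text):
    """Parse output of: nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"""
    gpus = []
    for line in text.strip().splitlines():
        # rsplit keeps any comma inside the GPU name intact
        name, total, free = (field.strip() for field in line.rsplit(",", 2))
        gpus.append({
            "name": name,
            "total_mib": int(total.split()[0]),  # "15360 MiB" -> 15360
            "free_mib": int(free.split()[0]),
        })
    return gpus


# Illustrative sample line (hypothetical values).
sample = "Tesla T4, 15360 MiB, 15101 MiB"
gpus = parse_gpu_csv(sample)
for gpu in gpus:
    print(f"{gpu['name']}: {gpu['free_mib'] / 1024:.1f} GiB free of {gpu['total_mib'] / 1024:.1f} GiB")
```

In Colab you could capture the real output with `subprocess.run([...], capture_output=True, text=True).stdout` and pass it to the same function.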

In [None]:
#@title 1. [Google Drive & Build Environment](https://github.com/yushan777/dbsd-dec-2022/blob/main/docs/dbsd-dec-2022-readme.md#cell-2-build-environment)

#@markdown > This may take a couple of minutes.....
#@markdown >
#@markdown > Mount Google Drive : You will be asked for permission.\ 

from google.colab import drive
from os import path

google_drive_dir = '/content/drive'

if path.exists(google_drive_dir)==False: 
  drive.mount(google_drive_dir)
  print(f'Google Drive mounted to {google_drive_dir}')
else: 
  print(f'Google Drive already mounted at {google_drive_dir}')

# =================================================================================
# =================================================================================

# Uninstalling all of PyTorch first is probably unnecessary, so that step is left commented out.
#print("Uninstalling existing Pytorch...")
#%pip -q uninstall torch torchtext torchaudio torchvision --y
# only torchtext is removed
%pip -q uninstall torchtext --y

print("Installing ShivamShrirao/Diffusers...")
# tested on commit fbdf0a17055ffa34679cb34d986fabc1296d0785 2023-03-02
%pip -q install git+https://github.com/ShivamShrirao/diffusers.git@fbdf0a17055ffa34679cb34d986fabc1296d0785

# get train_dreambooth.py from same revision 
!wget https://github.com/ShivamShrirao/diffusers/raw/fbdf0a17055ffa34679cb34d986fabc1296d0785/examples/dreambooth/train_dreambooth.py

# fetch the conversion scripts from the same pinned revision, which is recent enough to support safetensors
!wget -q https://github.com/ShivamShrirao/diffusers/raw/fbdf0a17055ffa34679cb34d986fabc1296d0785/scripts/convert_diffusers_to_original_stable_diffusion.py
!wget -q https://github.com/ShivamShrirao/diffusers/raw/fbdf0a17055ffa34679cb34d986fabc1296d0785/scripts/convert_original_stable_diffusion_to_diffusers.py

# install requisite packages
%pip install -q torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
%pip install -q -U --pre triton==2.0.0.dev20221030
%pip install omegaconf==2.3.0 # required for orig-diffusers conversion script
%pip install pytorch-lightning==1.8.5 # required for orig-diffusers conversion script
%pip install -q accelerate==0.12.0 transformers==4.26.0 ftfy==6.1.1 bitsandbytes==0.35.0 gradio natsort safetensors
%pip install -q xformers==0.0.13
%pip install -q gdown
#%pip install -q https://github.com/yushan777/xformers-wheels/releases/download/xformers-0.015.dev0-py38/xformers-0.0.15.dev0-cp38-cp38-linux_x86_64.whl

# =====================================================================================================================
# remove instances of param 'keep_fp32_wrapper=True' from file 'train_dreambooth.py'
import fileinput
filename = 'train_dreambooth.py'
with fileinput.FileInput(filename, inplace=True, backup='~bak') as file:
    for line in file:
        print(line.replace(', keep_fp32_wrapper=True', ''), end='')



#@markdown ___
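The `fileinput` loop above rewrites `train_dreambooth.py` in place, keeps a `train_dreambooth.py~bak` backup, and drops every `, keep_fp32_wrapper=True` substring. The same edit can be sketched with `pathlib`; `strip_argument` and `patch_file` are hypothetical helper names, and the sample line is only shaped like the calls being patched, not copied from the script.

```python
from pathlib import Path

ARGUMENT = ", keep_fp32_wrapper=True"


def strip_argument(source: str, argument: str = ARGUMENT) -> str:
    """Remove every occurrence of `argument` from `source`."""
    return source.replace(argument, "")


def patch_file(path: str, argument: str = ARGUMENT) -> int:
    """Rewrite `path` in place (backup saved as `<name>~bak`); return how many occurrences were removed."""
    target = Path(path)
    original = target.read_text()
    target.with_name(target.name + "~bak").write_text(original)
    target.write_text(strip_argument(original, argument))
    return original.count(argument)


sample = "unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)"
print(strip_argument(sample))  # unet = accelerator.unwrap_model(unet)
```

Unlike the loop, `patch_file` also reports how many occurrences it removed, which makes it easy to notice if a different script revision no longer contains the argument.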


In [None]:
#@title 2. [Token Word, Class Word & Class Prompt](https://github.com/yushan777/dbsd-dec-2022/blob/main/docs/dbsd-dec-2022-readme.md#cell-3-token-word-class-word--class-prompt)

# TOKEN is a unique identifier linked to the subject that you are training
TOKEN_WORD = "zwx" #@param {type:"string"}
word_list = TOKEN_WORD.split()
if len(word_list) == 0:
    TOKEN_WORD = "zwx"
elif len(word_list) > 1:
    print("too many words in TOKEN_WORD - using only first word")
    TOKEN_WORD = word_list[0]

# CLASS is a coarse class descriptor of the subject (e.g. person, man, woman, cat, dog, watch, etc.).
CLASS_WORD = "person" #@param {type:"string"}
word_list = CLASS_WORD.split()
if len(word_list) == 0:
    CLASS_WORD = "person"
elif len(word_list) > 1:
    print("too many words in CLASS_WORD - using only first word")
    CLASS_WORD = word_list[0]

# CLASS_PROMPT is the prompt used to generate class (regularization) images.
CLASS_PROMPT = "person" #@param {type:"string"}
if len(CLASS_PROMPT.strip()) == 0:
  CLASS_PROMPT = "person"

#@markdown > Example Token words : `skf, lun, whoo, olis...` \
#@markdown > Example class words : `person, man, woman, dog, cat, car` \
#@markdown > 
#@markdown > Token and Class word are used together to represent your subject being trained. \
#@markdown > Example: `photo of [zwf] [person] sitting in a cafe.` \
#@markdown > 
#@markdown > Class Prompt is used to generate the class images used for regularization. It can be the same as the Class Word or more descriptive, such as `photo of a person`. If you are providing your own class images, it is ignored.

#@markdown ___
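Cell 2 applies the same rule to both word fields: a blank entry falls back to a default and a multi-word entry is cut down to its first word. Factored into a hypothetical helper, the rule and its main pitfall look like this (the instance prompt is built the same way as `INSTANCE_PROMPT` in Cell 4):

```python
def first_word_or_default(value: str, default: str, label: str = "value") -> str:
    """Return the first whitespace-separated word of `value`, or `default` if it is blank."""
    words = value.split()
    if not words:
        return default
    if len(words) > 1:
        print(f"too many words in {label} - using only first word")
    return words[0]


TOKEN_WORD = first_word_or_default("zwx", "zwx", "TOKEN_WORD")
# A two-word class is truncated to its first word, which is rarely what you want:
CLASS_WORD = first_word_or_default("young woman", "person", "CLASS_WORD")
INSTANCE_PROMPT = f"{TOKEN_WORD} {CLASS_WORD}"
print(INSTANCE_PROMPT)  # zwx young
```

This is why the class word should be a single coarse noun such as `woman`, with any extra description going into the Class Prompt instead.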


In [None]:
#@title 3. [Download / Convert Model & Set Model Path](https://github.com/yushan777/dbsd-dec-2022/blob/main/docs/dbsd-dec-2022-readme.md#cell-3-download--convert-model--set-model-path)

# =========================================================================================
# HUGGING FACE DIFFUSERS
# =========================================================================================
MODEL_NAME = ''
continue_diff_converstion = False

#@markdown Name or Path of the diffusers model to be trained on. (this can be a HuggingFace repo address or a local directory)
HUGGINGFACE_MODEL_PATH = "runwayml/stable-diffusion-v1-5" #@param {type:"string"}
#@markdown Use this field to download a checkpoint/safetensors model directly; it will then be converted to diffusers format. \
CKPT_SAFETENSOR_URL = '' #@param {type:"string"}

#@markdown > #### **If both fields are filled, then the ckpt/safetensor model will be used**

# if both empty, do nothing
# if just HF path filled, set model name
# if both filled, process ckpt and convert to diffusers
# if only ckpt filled, process ckpt and convert to diffusers

if len(HUGGINGFACE_MODEL_PATH) > 0 and len(CKPT_SAFETENSOR_URL) == 0:  #only HF filled
  MODEL_NAME = HUGGINGFACE_MODEL_PATH
elif len(HUGGINGFACE_MODEL_PATH) > 0 and len(CKPT_SAFETENSOR_URL) > 0: #both filled, use last one
  continue_diff_converstion = True
elif len(HUGGINGFACE_MODEL_PATH) == 0 and len(CKPT_SAFETENSOR_URL) > 0: #only ckpt filled
  continue_diff_converstion = True

# =========================================================================================
# CHECKPOINTS / SAFETENSORS
# =========================================================================================
if continue_diff_converstion:
  import os
  from glob import glob
  from natsort import natsorted

  googlelink_prefix = 'https://drive.google.com/file/d/'
  googlelink_suffix = '/view?usp=share_link'

  download_dir = '/content/downloads'

  if path.exists(download_dir)==False:
    os.mkdir(download_dir)

  # check if CKPT_SAFETENSOR_URL is a Google Drive share link
  if CKPT_SAFETENSOR_URL.startswith(googlelink_prefix):
    print("Google Drive share link detected")
    # the file ID is the path segment after '/file/d/', whatever '/view?usp=...' suffix follows it
    share_id = CKPT_SAFETENSOR_URL[len(googlelink_prefix):].split('/')[0]
    CKPT_SAFETENSOR_URL = f'https://drive.google.com/uc?id={share_id}'
    # gdown is unreliable on Colab for large files, so wget is used instead
    # see https://bcrf.biochem.wisc.edu/2021/02/05/download-google-drive-files-using-wget/
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=$share_id' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$share_id" --content-disposition -P "$download_dir" && rm -rf /tmp/cookies.txt
  else:
    #download model file into download dir
    !wget "$CKPT_SAFETENSOR_URL" --content-disposition -P "$download_dir"

  # get the filename
  search_pattern = f'{download_dir}/*'
  file_list = natsorted(glob(f'{search_pattern}', recursive=False))
  # stop if nothing was downloaded
  if len(file_list)==0:
    print("Error: No files downloaded.")
  else:  
    # use the last file in natural sort order (download_dir should contain only this one model file)
    file_path = file_list[-1]  

    if file_path.endswith('.ckpt') or file_path.endswith('.safetensors'):    
      # filename only
      filename = os.path.basename(os.path.normpath(file_path))
      # save filename without extension
      filename_no_ext = os.path.splitext(filename)[0]
      #print(filename_no_ext)
      # if orig. filename ext. is safetensors then set parameter flag

      from_safetensors = ""
      if filename.endswith('safetensors'):
        from_safetensors = "--from_safetensors"

      # =========================================================================================
      # DIFFUSERS DIRECTORY
      # =========================================================================================
      DIFFUSERS_DIR = ""

      if len(DIFFUSERS_DIR)==0:
        DIFFUSERS_DIR = f'/content/diffusers-{filename_no_ext}'

      print("converting to diffusers... " + DIFFUSERS_DIR)


      # convert to diffusers 
      !python convert_original_stable_diffusion_to_diffusers.py --checkpoint_path "$file_path" --dump_path "$DIFFUSERS_DIR" $from_safetensors

      # set the model name/path
      MODEL_NAME=DIFFUSERS_DIR

      print("Model converted to diffusers : " + f'{DIFFUSERS_DIR}')
    else:
      print("Error : File downloaded is not a ckpt or safetensors model file.")

#@markdown > Example HF paths: \
#@markdown > `runwayml/stable-diffusion-v1-5` (HuggingFace)<br>
#@markdown > `drive/MyDrive/diffusers-models/sd15` (local) \
#@markdown > A local path can be relative to `/content` (as shown above) or absolute. \
#@markdown > `/content` is this Colab's root. \ 


#@markdown > Example URLs:<br>
#@markdown > `https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt` \
#@markdown >`https://civitai.com/api/download/models/6987?type=Model&format=PickleTensor` \
#@markdown >`https://civitai.com/api/download/models/6987?type=Model&format=SafeTensor` \
#@markdown >`https://drive.google.com/file/d/1JEZCyW36ziz9Fn482MUG8T0_2P4FGG-/view?usp=share_link` _(google drive share link)_ \
#@markdown >

#@markdown ___
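Cell 3 picks the model source with a simple precedence rule (a checkpoint URL wins over a HuggingFace path) and turns a Google Drive share link into a `uc?id=` URL. Both steps can be sketched as plain functions; `choose_model_source` and `drive_file_id` are hypothetical names, the regular expression assumes the `https://drive.google.com/file/d/<id>/...` link shape shown in the examples, and `abc_DEF-123` is a placeholder ID.

```python
import re

# Assumes share links of the form https://drive.google.com/file/d/<id>/view?usp=...
DRIVE_FILE_ID = re.compile(r"drive\.google\.com/file/d/([A-Za-z0-9_-]+)")


def choose_model_source(hf_path: str, ckpt_url: str):
    """Return ("convert", url), ("diffusers", path) or (None, "") following Cell 3's precedence."""
    if ckpt_url.strip():
        return "convert", ckpt_url.strip()
    if hf_path.strip():
        return "diffusers", hf_path.strip()
    return None, ""


def drive_file_id(url: str):
    """Extract the file ID from a Google Drive share link, or return None."""
    match = DRIVE_FILE_ID.search(url)
    return match.group(1) if match else None


for link in [
    "https://drive.google.com/file/d/abc_DEF-123/view?usp=share_link",
    "https://drive.google.com/file/d/abc_DEF-123/view?usp=sharing",
]:
    print(f"https://drive.google.com/uc?id={drive_file_id(link)}")
```

Matching the ID with a pattern rather than stripping one fixed suffix means both `usp=share_link` and `usp=sharing` links resolve to the same file.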


In [None]:
#@title 4. [Instance, Class, Output Directory, Concepts List Settings](https://github.com/yushan777/dbsd-dec-2022/blob/main/docs/dbsd-dec-2022-readme.md#cell-4-instance-class-output-directory-concepts-list-settings)

# =========================================================================================
# INSTANCE IMAGE DIRECTORY 
# =========================================================================================
INSTANCE_DIR = 'training_images/zwx' #@param {type:"string"}

if len(INSTANCE_DIR) == 0: 
  INSTANCE_DIR = f'/content/training_images/{TOKEN_WORD}' 
else:
  if not INSTANCE_DIR.startswith('/content/'):
    INSTANCE_DIR = '/content/' + INSTANCE_DIR

# =========================================================================================
# CLASS IMAGE DIRECTORY
# =========================================================================================
CLASS_DIR =  '/content/class_images' #@param {type:"string"}
if len(CLASS_DIR) == 0: 
  CLASS_DIR = f'/content/class_images/{CLASS_WORD}'
else:
  if not CLASS_DIR.startswith('/content/'):
    CLASS_DIR = '/content/' + CLASS_DIR

# =========================================================================================
# SAVE OUTPUT DIRECTORY TO GOOGLE DRIVE?
# =========================================================================================
from google.colab import drive
from os import path

google_drive_dir = '/content/drive'

SAVE_TO_GOOGLE_DRIVE = False #@param {type:"boolean"}

if SAVE_TO_GOOGLE_DRIVE==True:
  if path.exists(google_drive_dir)==False: 
    drive.mount(google_drive_dir)
    print(f'Google Drive mounted to {google_drive_dir}')
  else:
    print(f'Google Drive already mounted at {google_drive_dir}')
    
# =========================================================================================
# OUTPUT DIR PATH
# =========================================================================================
OUTPUT_DIR = "stable_diffusion_weights/zwx" #@param {type:"string"}

if SAVE_TO_GOOGLE_DRIVE:
  if len(OUTPUT_DIR)==0:
    OUTPUT_DIR = f'{google_drive_dir}' + "/MyDrive/" + f'stable_diffusion_weights/{TOKEN_WORD}'
  else:
    OUTPUT_DIR = f'{google_drive_dir}' + "/MyDrive/" + OUTPUT_DIR
else:
  if len(OUTPUT_DIR)==0:
    OUTPUT_DIR = "/content/" + f'stable_diffusion_weights/{TOKEN_WORD}'
  else:
    if OUTPUT_DIR.startswith('/content/') == False:
      OUTPUT_DIR = '/content/' + f'{OUTPUT_DIR}'    

print(f"[*] Weights will be saved at {OUTPUT_DIR}")

!mkdir -p "$OUTPUT_DIR"

# =========================================================================================
# CONCEPTS LIST
# =========================================================================================
# variables used so far are for one concept / subject only
# they are not essential as you can just type in literal strings as 
# shown in the commented-out concepts below
# if you choose to do multi-concepts, you must create the directories manually 
# and enter the correct paths

INSTANCE_PROMPT = f'{TOKEN_WORD} {CLASS_WORD}'
#INSTANCE_PROMPT2 = f'{TOKEN_WORD2} {CLASS_WORD2}'
#INSTANCE_PROMPT3 = f'{TOKEN_WORD3} {CLASS_WORD3}'

# You can also add multiple concepts here.
# The more concepts you have, the more you may need to increase MAX_TRAIN_STEPS (`--max_train_steps`) in Cell 6.
concepts_list = [
    {
        "instance_prompt":      f'{INSTANCE_PROMPT}',
        "class_prompt":         f'{CLASS_PROMPT}',
        "instance_data_dir":    f'{INSTANCE_DIR}',
        "class_data_dir":       f'{CLASS_DIR}'
    },
#    {
#        "instance_prompt":      "skf person",
#        "class_prompt":         "person",
#        "instance_data_dir":    "/content/training_images/skf",
#        "class_data_dir":       "/content/drive/MyDrive/class_images/SD1-5/person-ddim"
#    },
#     {
#         "instance_prompt":      "ukj dog",
#         "class_prompt":         "dog",
#         "instance_data_dir":    "/content/training_images/ukj",
#         "class_data_dir":       "/content/drive/MyDrive/class_images/SD1-5/dog-ddim"
#     }
]

import json
import os
# create an instance directory for each concept's training images
for c in concepts_list:
    os.makedirs(c["instance_data_dir"], exist_ok=True)

# create the concepts_list.json file
with open("concepts_list.json", "w") as f:
    json.dump(concepts_list, f, indent=4)

# =========================================================================================





#@markdown > **INSTANCE_DIR**: Directory for instance (training) images. Leave blank for default. \
#@markdown > Default is `training_images/{TOKEN_WORD}` \

#@markdown > **CLASS_DIR**: Directory for class images. Leave blank for default. \
#@markdown > Default is `class_images/{CLASS_WORD}` \
#@markdown > When training starts, if this directory holds fewer than NUM_CLASS_IMAGES class images, the missing ones are generated first (slower). \


#@markdown > **SAVE_TO_GOOGLE_DRIVE**: Save trained models directly to google drive. \
#@markdown > If you are saving multiple models at intervals as well as converting to ckpts or safetensors, you will need a lot of storage, so this is off by default.
#@markdown > You can selectively save your model(s) to Google Drive after training.

#@markdown > **OUTPUT_DIR**: Directory to save the trained model(s) in. Leave empty for default. \
#@markdown > _Default is : \
#@markdown > `stable_diffusion_weights/{TOKEN_WORD}` \
#@markdown > If you are saving to Google Drive then it will be: \
#@markdown > `drive/MyDrive/stable_diffusion_weights/{TOKEN_WORD}`_
#@markdown ___
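Cell 4 resolves each directory field the same way (blank means default; anything not starting with `/content/` is prefixed with it) and then writes `concepts_list.json`. A minimal sketch of that logic, assuming hypothetical helpers `resolve_dir` and `make_concept`. It uses `os.path.join`, which keeps an absolute path outside `/content` unchanged instead of prefixing it, a small deliberate difference from the cell.

```python
import json
import os


def resolve_dir(value: str, default: str, root: str = "/content") -> str:
    """Blank -> default; relative -> under root; absolute -> kept as-is."""
    value = value.strip()
    if not value:
        return default
    # os.path.join discards `root` when `value` is already absolute
    return os.path.normpath(os.path.join(root, value))


def make_concept(token, class_word, class_prompt, instance_dir, class_dir):
    """One entry of concepts_list.json, with the same keys Cell 4 uses."""
    return {
        "instance_prompt": f"{token} {class_word}",
        "class_prompt": class_prompt,
        "instance_data_dir": instance_dir,
        "class_data_dir": class_dir,
    }


concepts = [
    make_concept(
        "zwx", "person", "person",
        resolve_dir("training_images/zwx", "/content/training_images/zwx"),
        resolve_dir("", "/content/class_images/person"),
    ),
]
print(json.dumps(concepts, indent=4))
```

For multiple concepts, append one `make_concept(...)` entry per subject, each with its own token word and instance directory.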


In [None]:
#@title 5. Upload Instance (Training) Images 🌌🌄🏞️

#@markdown If the instance directory defined in [Cell 4](#scrollTo=-APTAioh6o5A) does not already contain images, manually upload instance images to the folder. \ 
#@markdown You can use the file manager on the left panel to upload \
print("drag and drop your images into the folder : " + f'{INSTANCE_DIR}')

#@markdown ___

#### Training Parameter Combinations

Use the table below to choose the best flags based on your memory and speed requirements. Tested on Tesla T4 GPU.


| `mixed_precision` | `train_batch_size` | `gradient_accumulation_steps` | `gradient_checkpointing` | `use_8bit_adam` | VRAM usage (GB) | Speed (it/s) |
| ---- | ------------------ | ----------------------------- | ----------------------- | --------------- | ---------- | ------------ |
| fp16 | 1                  | 1                             | TRUE                    | TRUE            | 9.92       | 0.93         |
| no   | 1                  | 1                             | TRUE                    | TRUE            | 10.08      | 0.42         |
| fp16 | 2                  | 1                             | TRUE                    | TRUE            | 10.4       | 0.66         |
| fp16 | 1                  | 1                             | FALSE                   | TRUE            | 11.17      | 1.14         |
| no   | 1                  | 1                             | FALSE                   | TRUE            | 11.17      | 0.49         |
| fp16 | 1                  | 2                             | TRUE                    | TRUE            | 11.56      | 1            |
| fp16 | 2                  | 1                             | FALSE                   | TRUE            | 13.67      | 0.82         |
| fp16 | 1                  | 2                             | FALSE                   | TRUE            | 13.7       | 0.83          |
| fp16 | 1                  | 1                             | TRUE                    | FALSE           | 15.79      | 0.77         |


Add the `--gradient_checkpointing` flag to bring VRAM usage down to around 9.92 GB, at some cost in speed.

Remove the `--use_8bit_adam` flag to use full-precision Adam. This requires 15.79 GB with `--gradient_checkpointing`, otherwise 17.8 GB.

Remove the `--train_text_encoder` flag to reduce memory usage further, at the cost of output quality.
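To pick a row programmatically, a small sketch can filter the table by your VRAM budget, take the fastest remaining combination, and turn its speed into a rough training-time estimate. The rows are copied from the T4 table above; the `Row` tuple and helper names are hypothetical, and the time estimate assumes one `it` corresponds to one training step (it also ignores class-image generation and saving).

```python
from collections import namedtuple

Row = namedtuple("Row", "mixed_precision train_batch_size grad_accum grad_ckpt adam_8bit vram_gb it_per_s")

# Values from the Tesla T4 table above.
T4_ROWS = [
    Row("fp16", 1, 1, True, True, 9.92, 0.93),
    Row("no", 1, 1, True, True, 10.08, 0.42),
    Row("fp16", 2, 1, True, True, 10.4, 0.66),
    Row("fp16", 1, 1, False, True, 11.17, 1.14),
    Row("no", 1, 1, False, True, 11.17, 0.49),
    Row("fp16", 1, 2, True, True, 11.56, 1.0),
    Row("fp16", 2, 1, False, True, 13.67, 0.82),
    Row("fp16", 1, 2, False, True, 13.7, 0.83),
    Row("fp16", 1, 1, True, False, 15.79, 0.77),
]


def fastest_within(budget_gb, rows=T4_ROWS):
    """Fastest row whose VRAM usage fits in `budget_gb`, or None."""
    fitting = [row for row in rows if row.vram_gb <= budget_gb]
    return max(fitting, key=lambda row: row.it_per_s) if fitting else None


def estimated_minutes(max_train_steps, it_per_s):
    """Rough wall-clock estimate, assuming one it == one training step."""
    return max_train_steps / it_per_s / 60


best = fastest_within(12)
print(best)
print(f"~{estimated_minutes(2000, best.it_per_s):.0f} min for 2000 steps")
```

With a 12 GB budget this selects the fp16 / batch size 1 / no gradient checkpointing / 8-bit Adam row, which matches the flags used in Cell 6.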

In [None]:
#@title 6. [Training!](https://github.com/yushan777/dbsd-dec-2022/blob/main/docs/dbsd-dec-2022-readme.md#cell-6-training)

NUM_CLASS_IMAGES = 200 #@param {type:"integer"}
MAX_TRAIN_STEPS = 2000 #@param {type:"integer"}
SAVE_INTERVAL = 500 #@param {type:"integer"}
SAVE_MIN_STEPS = 500 #@param {type:"integer"}
SAMPLE_PROMPT = f'a photo of {TOKEN_WORD} {CLASS_WORD}'

!accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --revision="fp16" \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images="$NUM_CLASS_IMAGES" \
  --sample_batch_size=4 \
  --max_train_steps="$MAX_TRAIN_STEPS" \
  --save_interval="$SAVE_INTERVAL" \
  --save_min_steps="$SAVE_MIN_STEPS" \
  --save_sample_prompt="$SAMPLE_PROMPT" \
  --concepts_list="concepts_list.json"

# Reduce `--save_interval` below `--max_train_steps` to save weights from intermediate steps.
# `--save_sample_prompt` can be the same as the instance prompt to generate intermediate samples (saved with the weights in a `samples` directory).

#@markdown > **NUM_CLASS_IMAGES**: Number of Class images to generate and/or use. \
#@markdown > **MAX_TRAIN_STEPS**: Maximum number of steps to train. \
#@markdown > **SAVE_INTERVAL**: Save weights every N steps. (Set this to the same as or greater than MAX_TRAIN_STEPS to save only one set of weights.) \
#@markdown > **SAVE_MIN_STEPS**: Start saving weights at and after N steps. \

#@markdown ___
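Going by the descriptions above (save every `SAVE_INTERVAL` steps, starting at `SAVE_MIN_STEPS`), the step folders you can expect in `OUTPUT_DIR` can be listed ahead of time, which helps when budgeting disk space. This is a sketch of the documented semantics, not of `train_dreambooth.py` itself; in particular, the extra save at the final step is an assumption that you can switch off with `final_save=False`.

```python
def expected_save_steps(max_train_steps, save_interval, save_min_steps, final_save=True):
    """Steps at which weights should be written, under the assumptions stated above."""
    if save_interval <= 0:
        raise ValueError("save_interval must be positive")
    steps = [
        step for step in range(save_interval, max_train_steps + 1, save_interval)
        if step >= save_min_steps
    ]
    # Assumption: the script also saves once more when training finishes.
    if final_save and max_train_steps not in steps:
        steps.append(max_train_steps)
    return steps


print(expected_save_steps(2000, 500, 500))   # [500, 1000, 1500, 2000]
print(expected_save_steps(2000, 2000, 500))  # [2000]
```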


In [None]:
#@title 7. Generate Grid of Preview Images Generated During Training (Optional)
#@markdown Run to generate a grid of preview images from all saved weights.
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

weights_folder = OUTPUT_DIR
# step folders only (numeric names, skipping "0"), sorted numerically so 1000 comes after 500
folders = sorted(
    [f for f in os.listdir(weights_folder)
     if f.isdigit() and f != "0" and os.path.isdir(os.path.join(weights_folder, f, "samples"))],
    key=int)
if not folders:
    raise FileNotFoundError(f"No step folders with a 'samples' directory in {weights_folder}")

row = len(folders)
# use the widest row so steps with more samples are not cut off
col = max(len(os.listdir(os.path.join(weights_folder, f, "samples"))) for f in folders)
scale = 4
# squeeze=False keeps `axes` two-dimensional even for a single row or column
fig, axes = plt.subplots(row, col, figsize=(col*scale, row*scale), squeeze=False, gridspec_kw={'hspace': 0, 'wspace': 0})

for i, folder in enumerate(folders):
    image_folder = os.path.join(weights_folder, folder, "samples")
    images = sorted(os.listdir(image_folder))
    for j in range(col):
        currAxes = axes[i, j]
        currAxes.axis('off')
        if j >= len(images):
            continue  # this step saved fewer samples than the widest row
        if i == 0:
            currAxes.set_title(f"Image {j}")
        if j == 0:
            currAxes.text(-0.1, 0.5, folder, rotation=0, va='center', ha='center', transform=currAxes.transAxes)
        currAxes.imshow(mpimg.imread(os.path.join(image_folder, images[j])))
        
plt.tight_layout()
plt.savefig('grid.png', dpi=72)

#@markdown ___
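Step folders are named after the step number, so they need a numeric sort: a plain string sort puts `1000` before `500`. A minimal sketch of that filtering, with `step_dirs` as a hypothetical helper and a temporary directory standing in for `OUTPUT_DIR`:

```python
import os
import tempfile


def step_dirs(weights_folder):
    """Names of step sub-directories (digits only, step > 0), in numeric order."""
    names = [
        name for name in os.listdir(weights_folder)
        if name.isdigit() and int(name) > 0
        and os.path.isdir(os.path.join(weights_folder, name))
    ]
    return sorted(names, key=int)


with tempfile.TemporaryDirectory() as tmp:
    for name in ["1000", "500", "1500", "0", "notes"]:
        os.makedirs(os.path.join(tmp, name))
    lexicographic = sorted(os.listdir(tmp))
    numeric = step_dirs(tmp)

print(lexicographic)  # ['0', '1000', '1500', '500', 'notes']
print(numeric)        # ['500', '1000', '1500']
```

The same filter is a safe replacement wherever a cell calls `int()` on folder names, since non-numeric entries are skipped instead of raising `ValueError`.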


In [None]:
#@title 8a. [Convert All Weights To ckpt / safetensors.](https://github.com/yushan777/dbsd-dec-2022/blob/main/docs/dbsd-dec-2022-readme.md#cell-8a-convert-all-weights-to-ckpt--safetensors) (To convert just one weight use cell 8b.)
import os
from glob import glob
from natsort import natsorted
# from os.path import exists

MODEL_NAME_PREFIX = "tom_zwx" #@param {type:"string"}
MODEL_FORMAT = "ckpt" #@param ["ckpt", "safetensors"] {type:"string"}

use_safetensors = ""
if MODEL_FORMAT == 'safetensors': 
  use_safetensors = "--use_safetensors"

half_arg = ""
#@markdown  Whether to convert to fp16 (reduces the file size to about 2 GB).
fp16 = True #@param {type: "boolean"}
if fp16: 
  half_arg = "--half"

# check dir
#print(OUTPUT_DIR)

search_pattern = OUTPUT_DIR + '/*/'
folder_list = natsorted(glob(f'{search_pattern}', recursive=False))
#print(len(folder_list))

for folderpath in folder_list:
    # get the last part of the path
    step_val = os.path.basename(os.path.normpath(folderpath))
    print("folderpath = " + folderpath)
    #print(step_val)
    if step_val.isdigit() and int(step_val) > 0:
      ckpt_path = folderpath + f'{MODEL_NAME_PREFIX}_' + f'{step_val}'
      if MODEL_FORMAT == 'ckpt':
        ckpt_path += '.ckpt'
      else:
        ckpt_path += '.safetensors'

      !python convert_diffusers_to_original_stable_diffusion.py --model_path $folderpath --checkpoint_path $ckpt_path $half_arg $use_safetensors
      

      print("checkpoint saved : " + ckpt_path)

print("Complete.")
print("You can now move the checkpoints to your Google Drive or download them directly.")

#@markdown ___
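For reference, the command issued for each step folder can be built as an argument list, which is convenient if you prefer `subprocess.run` over the `!` shell syntax. The flags are the ones used in cells 8a and 8b; `conversion_command` is a hypothetical helper, and it uses `os.path.join` so the checkpoint lands inside the step folder whether or not the folder path ends with `/`.

```python
import os
import shlex


def conversion_command(step_dir, prefix, model_format="ckpt", half=True):
    """argv for convert_diffusers_to_original_stable_diffusion.py, mirroring cells 8a/8b."""
    step = os.path.basename(os.path.normpath(step_dir))
    ext = ".safetensors" if model_format == "safetensors" else ".ckpt"
    checkpoint_path = os.path.join(step_dir, f"{prefix}_{step}{ext}")
    cmd = [
        "python", "convert_diffusers_to_original_stable_diffusion.py",
        "--model_path", step_dir,
        "--checkpoint_path", checkpoint_path,
    ]
    if half:
        cmd.append("--half")
    if model_format == "safetensors":
        cmd.append("--use_safetensors")
    return cmd


with_slash = conversion_command("/content/stable_diffusion_weights/zwx/800/", "tom_zwx", "safetensors")
without_slash = conversion_command("/content/stable_diffusion_weights/zwx/800", "tom_zwx", "safetensors")
print(shlex.join(without_slash))
```

Running it would be `subprocess.run(cmd, check=True)` from `/content`, where Cell 1 downloaded the conversion script.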


In [None]:
#@title 8b. [Convert Specific Weight To ckpt / safetensors](https://github.com/yushan777/dbsd-dec-2022/blob/main/docs/dbsd-dec-2022-readme.md#cell-8b-convert-specific-weight-to-ckpt--safetensors)
import os

#@markdown Specify the weights directory to use (leave blank for the last saved weights).
WEIGHTS_DIR = "/content/stable_diffusion_weights/zwx/800/" #@param {type:"string"}
if WEIGHTS_DIR == "":
    from natsort import natsorted
    from glob import glob
    import os
    WEIGHTS_DIR = natsorted(glob(OUTPUT_DIR + os.sep + "*"))[-1]
print(f"[*] WEIGHTS_DIR={WEIGHTS_DIR}")

model_name_prefix = "tom_zwx" #@param {type:"string"}
model_format = "safetensors" #@param ["ckpt", "safetensors"] {type:"string"}

use_safetensors = ""
if model_format == 'safetensors': 
  use_safetensors = "--use_safetensors"

step_val = os.path.basename(os.path.normpath(WEIGHTS_DIR))
ckpt_path = os.path.join(WEIGHTS_DIR, f'{model_name_prefix}_{step_val}')
if model_format == 'ckpt':
  ckpt_path += '.ckpt'
else:
  ckpt_path += '.safetensors'
print(ckpt_path)
half_arg = ""
#@markdown  Whether to convert to fp16 (takes half the space, about 2 GB).
fp16 = True #@param {type: "boolean"}
if fp16:
    half_arg = "--half"


!python convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR  --checkpoint_path $ckpt_path $half_arg $use_safetensors
print(f"[*] Converted ckpt saved at {ckpt_path}")

#@markdown ___

In [None]:
#@title 9. List Weight Directories
import os
from glob import glob
from natsort import natsorted
search_pattern = OUTPUT_DIR + '/*/'
folder_list = natsorted(glob(f'{search_pattern}', recursive=False))
#print(len(folder_list))

for folderpath in folder_list:
  step_val = os.path.basename(os.path.normpath(folderpath))
  if step_val.isdigit() and int(step_val) > 0:
    print(folderpath)

#@markdown ___


In [None]:
#@title 10. (Optional) Inference - Image Generation
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
from IPython.display import display

#@markdown Specify the weights directory to use for inference. (Use Cell [9](#scrollTo=zPpKtMC_mz7t) to show a list of directories.)
WEIGHTS_DIR = "/content/stable_diffusion_weights/zwx/900/" #@param {type:"string"}

model_path = WEIGHTS_DIR             # If you want to use previously trained model saved in gdrive, replace this with the full path of model in gdrive

pipe = StableDiffusionPipeline.from_pretrained(model_path, safety_checker=None, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
g_cuda = torch.Generator(device='cuda')
seed = 52362 #@param {type:"number"}
g_cuda.manual_seed(seed)


#@markdown Replace `zwx` in the prompt with your own token word.
prompt = "photo of zwx person" #@param {type:"string"}
negative_prompt = "" #@param {type:"string"}
num_samples = 2 #@param {type:"number"}
guidance_scale = 7 #@param {type:"number"}
num_inference_steps = 24 #@param {type:"number"}
height = 512 #@param {type:"number"}
width = 512 #@param {type:"number"}

with autocast("cuda"), torch.inference_mode():
    images = pipe(
        prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_samples,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=g_cuda
    ).images

for img in images:
    display(img)

#@markdown ___

In [None]:
#@title (Optional) Delete diffusers weights and old weights and keep only the ckpt / safetensors to free up drive space.

#@markdown [ ! ] Caution: only execute this if you are sure you want to delete the diffusers-format weights and keep only the ckpt / safetensors files.
import shutil
from glob import glob
import os
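keep_dir = os.path.normpath(WEIGHTS_DIR)
for f in glob(OUTPUT_DIR+os.sep+"*"):
    # compare normalized paths so a trailing '/' in WEIGHTS_DIR does not delete the folder being kept
    if os.path.normpath(f) != keep_dir:
        shutil.rmtree(f)
        print("Deleted", f)
for f in glob(WEIGHTS_DIR+"/*"):
    # keep converted checkpoints and json configs; remove the remaining sub-directories
    if not f.endswith((".ckpt", ".safetensors", ".json")):
        try:
            shutil.rmtree(f)
        except NotADirectoryError:
            continue
        print("Deleted", f)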

#@markdown ___
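The keep/delete test in this cell is easy to get wrong: a condition of the form `not a or not b or not c` is true for every path, because no filename ends with all three suffixes at once, and a `WEIGHTS_DIR` with a trailing `/` never equals a `glob` result (which has none). A small sketch of the intended checks, using hypothetical helper names and made-up paths:

```python
import os

KEEP_SUFFIXES = (".ckpt", ".safetensors", ".json")


def should_keep(path):
    # str.endswith accepts a tuple: True if the path ends with ANY of the suffixes
    return path.endswith(KEEP_SUFFIXES)


def always_true_check(path):
    # The pitfall: at least one of these three `not endswith` tests is always True
    return not path.endswith(".ckpt") or not path.endswith(".safetensors") or not path.endswith(".json")


def same_dir(a, b):
    # normpath strips trailing separators, so "/w/900/" and "/w/900" compare equal
    return os.path.normpath(a) == os.path.normpath(b)


paths = ["/w/900/zwx_900.ckpt", "/w/900/model_index.json", "/w/900/unet"]
print([p for p in paths if not should_keep(p)])    # ['/w/900/unet']
print([p for p in paths if always_true_check(p)])  # all three paths
print(same_dir("/w/900/", "/w/900"), "/w/900/" == "/w/900")  # True False
```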

In [None]:
#@title Delete a Directory and its Contents

#@markdown The file manager cannot delete a folder that isn't empty. This cell deletes a folder and everything below it recursively.<br><br>

#@markdown Specify the directory to be deleted (be sure it is the correct one). Colab's root folder is `/content/`. \
#@markdown The easiest way is to right-click the folder and copy its path. 

dir = "/content/stable_diffusion_weights/zwx" #@param {type:"string"}

!rm -rf "$dir"

print("you may need to refresh the file manager view if the folder is still visible after deleting.")

#@markdown ___
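`rm -rf` deletes whatever path it is given, so a mistyped field can remove far more than intended. If you want a guard, a minimal Python sketch could refuse the chosen root itself and anything outside it (here `/content`, which also contains the mounted Drive at `/content/drive`). `delete_tree` is a hypothetical helper, and the demo runs against a temporary directory rather than real Colab paths.

```python
import os
import shutil
import tempfile


def delete_tree(path, allowed_root="/content"):
    """Recursively delete `path`, refusing `allowed_root` itself and anything outside it."""
    target = os.path.realpath(path)
    root = os.path.realpath(allowed_root)
    if target == root or os.path.commonpath([target, root]) != root:
        raise ValueError(f"refusing to delete {target!r}")
    shutil.rmtree(target)


with tempfile.TemporaryDirectory() as root:
    victim = os.path.join(root, "stable_diffusion_weights", "zwx")
    os.makedirs(victim)
    delete_tree(victim, allowed_root=root)
    deleted = not os.path.exists(victim)
    try:
        delete_tree(root, allowed_root=root)
        refused = False
    except ValueError:
        refused = True

print(deleted, refused)  # True True
```

`os.path.realpath` resolves symlinks first, so a link inside `/content` that points elsewhere is also refused.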


In [None]:
#@title Empty Trash

!rm -rf ~/.local/share/Trash/

#@markdown ___

In [None]:
#@title Free runtime memory
exit()