 ![bse_logo_textminingcourse](https://bse.eu/sites/default/files/bse_logo_small.png)

# Fine Tuning Model: DreamBooth

We fine-tuned a Stable Diffusion model with the DreamBooth method, a technique for personalizing text-to-image generation that is designed to work with a small number of images, as in our project. DreamBooth updates the full diffusion model using these few images and their associated prompts. This allows the model to learn a specific subject by associating it with a unique concept or word provided in the prompt.


Our team adapted the code from the following source:
- Uysal, E. (2024, January 13). *Fine-Tuning Stable Diffusion with DreamBooth Method*. https://enessadi.medium.com/fine-tuning-stable-diffusion-with-dreambooth-method-52019b3599dd

We ran this code on Kaggle. A separate version set up to run on Google Colab, "DreamBooth_GoogleColab.ipynb", can be found in section "4. Fine Tuning Models" of our GitHub repository.


### Setup Environment 

For instructions on how to set up the GPU on Kaggle, please see the following link: https://github.com/maelysjb/Comics-GenAI/blob/main/README.md#:~:text=.gitignore-,README,-.md
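
Before installing anything, it can help to confirm that the GPU is actually attached to the session. The following is a minimal sketch, not part of the original notebook: the helper name `gpu_summary` is hypothetical, and it assumes the NVIDIA `nvidia-smi` tool is on the path whenever a GPU is enabled.

```python
import shutil
import subprocess


def gpu_summary():
    """Return one 'name, memory' line per visible GPU, or None if none is found."""
    # Assumption: GPU sessions expose the nvidia-smi command-line tool.
    if shutil.which("nvidia-smi") is None:
        return None
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
            capture_output=True,
            text=True,
        )
    except OSError:
        return None
    if result.returncode != 0:
        return None
    # Keep only non-empty lines, one per GPU.
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


gpus = gpu_summary()
if gpus:
    print("GPU(s) found:", gpus)
else:
    print("No GPU detected; enable the accelerator in the notebook settings.")
```

If the helper reports no GPU, the training step later in this notebook is unlikely to be practical, so it is worth fixing the accelerator setting first.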


### Import Packages and Install Diffusion libraries

In [None]:
# Diffuser libraries 

!pip install -qq "ipywidgets>=7,<8"
!git clone https://github.com/huggingface/diffusers
!pip install ./diffusers

In [None]:
# DreamBooth requirements & xFormers Library 

%cd /kaggle/working/diffusers/examples/dreambooth
!pip install -r requirements.txt
!pip install bitsandbytes
!pip install transformers gradio ftfy accelerate
!pip install xformers


In [None]:
!pip install torchvision --upgrade

In [None]:
import os
import shutil

# Image Display
from PIL import Image
import IPython.display as display
import matplotlib.pyplot as plt

In [None]:
# Model Training 
from diffusers import DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel
import torch

In [None]:
!pip install huggingface_hub

In [None]:
# Hugging Face 
from huggingface_hub import login

### Data preparation 

In [None]:
# Folder for datasets 
# Kaggle 
%cd /kaggle/working

if os.path.exists("/kaggle/working/custom_dataset"):
    print("Removing existing custom_dataset folder")
    !rm -rf /kaggle/working/custom_dataset

print("Creating new custom_dataset folder")
!mkdir /kaggle/working/custom_dataset
!mkdir /kaggle/working/custom_dataset/class_images
!mkdir /kaggle/working/custom_dataset/instance_images

print('Custom Dataset folder is created: /kaggle/working/custom_dataset')

In [None]:
# Automatically adding the data to the folders for Kaggle 

input_path = '/kaggle/input/unicorngirl/personnage'
output_path = '/kaggle/working/custom_dataset/instance_images'

files = os.listdir(input_path)
os.makedirs(output_path, exist_ok=True)

for file in files:
    src = os.path.join(input_path, file)
    dst = os.path.join(output_path, file)
    shutil.copy(src, dst)

print("Images copied successfully to the output directory.")
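
The copy loop above assumes that every entry in the input folder is a regular file. A slightly more defensive variant might skip subfolders and non-image files; this is a sketch under that assumption, with a hypothetical helper name `copy_images` and the same extension list used by the preprocessing functions in this notebook.

```python
import shutil
from pathlib import Path

# Same extensions as the image checks used elsewhere in this notebook.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp"}


def copy_images(input_dir, output_dir):
    """Copy image files (matched by extension) and return the copied file names."""
    src_dir, dst_dir = Path(input_dir), Path(output_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for src in sorted(src_dir.iterdir()):
        # Skip directories and files that do not look like images.
        if src.is_file() and src.suffix.lower() in IMAGE_EXTENSIONS:
            shutil.copy(src, dst_dir / src.name)
            copied.append(src.name)
    return copied
```

Calling `copy_images(input_path, output_path)` would then return the list of copied names, which makes it easy to check how many instance images are available for training.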

In [None]:
# Preprocessing data size function 

def resize_and_crop_images(folder_path, target_size=512):
    """
    Resize each image in a folder so that its shorter edge equals target_size,
    center-crop it to a target_size x target_size square, and save the result
    to /kaggle/working/resized_images. The original files are left unchanged.

    Parameters:
    - folder_path (str): Path to the folder containing the images.
    - target_size (int): Size of the shorter edge and of the square crop (default is 512).
    """
    # Define the output folder for resized and cropped images
    output_folder = '/kaggle/working/resized_images'
    
    # Create the output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)
    
    # Iterate through all files in the folder
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)

        # Check if the file is an image
        if os.path.isfile(file_path) and filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
            # Open the image
            image = Image.open(file_path)

            # Get the original width and height
            width, height = image.size

            # Calculate the new size while maintaining the aspect ratio
            if width <= height:
                new_width = target_size
                new_height = int(height * (target_size / width))
            else:
                new_width = int(width * (target_size / height))
                new_height = target_size

            # Resize the image
            resized_image = image.resize((new_width, new_height))

            left = (new_width - target_size) // 2
            top = (new_height - target_size) // 2
            right = (new_width + target_size) // 2
            bottom = (new_height + target_size) // 2

            # Perform the center crop
            cropped_image = resized_image.crop((left, top, right, bottom))
            
            # Save the cropped image to the output folder
            cropped_image.save(os.path.join(output_folder, filename))
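
To make the resize and center-crop arithmetic concrete, the sketch below repeats the size calculations from `resize_and_crop_images` without touching any files. The helper name `resize_and_crop_box` and the example dimensions are illustrative assumptions.

```python
def resize_and_crop_box(width, height, target_size=512):
    """Return the resized dimensions and the center-crop box for one image size."""
    # Scale so that the shorter edge equals target_size, keeping the aspect ratio.
    if width <= height:
        new_width = target_size
        new_height = int(height * (target_size / width))
    else:
        new_width = int(width * (target_size / height))
        new_height = target_size

    # Center a target_size x target_size square inside the resized image.
    left = (new_width - target_size) // 2
    top = (new_height - target_size) // 2
    right = (new_width + target_size) // 2
    bottom = (new_height + target_size) // 2
    return (new_width, new_height), (left, top, right, bottom)


# Landscape 1024x768 image: resized to 682x512, then cropped horizontally.
print(resize_and_crop_box(1024, 768))  # ((682, 512), (85, 0, 597, 512))
```

In both orientations the crop box is exactly `target_size` wide and tall, so every saved image ends up square.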

In [None]:
# Plotting images function 

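def show_images_in_one_row(folder_path, target_size=256):
    """Display all images in a folder side by side, scaled to a common width."""
    images = []

    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        if os.path.isfile(file_path) and filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
            img = Image.open(file_path)
            # Scale to target_size wide while keeping the aspect ratio
            img = img.resize((target_size, int(target_size * img.size[1] / img.size[0])))
            images.append(img)

    # plt.subplots raises an error for zero columns, so stop early
    if not images:
        print("No images found in", folder_path)
        return

    # Display images in one row; squeeze=False keeps axes 2D even for a single image
    fig, axes = plt.subplots(1, len(images), figsize=(len(images) * 3, 3), squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow(img)
        ax.axis('off')
    plt.show()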

### Preprocessing the data 

In [None]:
# Class Images
folder_path = '/kaggle/working/custom_dataset/class_images'
if len(os.listdir(folder_path)):
    resize_and_crop_images(folder_path)
    show_images_in_one_row(folder_path)

# Instance Images
folder_path_img = '/kaggle/working/custom_dataset/instance_images'
resize_and_crop_images(folder_path_img)
show_images_in_one_row(folder_path_img)

In [None]:
# Create output folder for the generated images 
output_folder = '/kaggle/working/outputs'

if os.path.exists(output_folder):
    print("Removing existing outputs folder")
    !rm -rf $output_folder

print("Creating new outputs folder")
!mkdir $output_folder

print('Output folder is created:', output_folder)
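
The same reset can be done without shell commands, which keeps the cell portable outside Jupyter. This is a minimal sketch with a hypothetical helper name `reset_dir`; unlike plain `mkdir`, it also creates missing parent folders.

```python
import shutil
from pathlib import Path


def reset_dir(path):
    """Delete the folder if it exists, then recreate it empty."""
    target = Path(path)
    if target.exists():
        shutil.rmtree(target)
    # parents=True also creates any missing parent folders.
    target.mkdir(parents=True)
    return target
```

For example, `reset_dir('/kaggle/working/outputs')` would leave an empty outputs folder regardless of what a previous run wrote there.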

### Log in to Hugging Face account

Replace the placeholder "TOKEN_FROM_HF" with your own personal Hugging Face token. The token is needed to save a private model and dataset to your account.

Instructions on using Hugging Face can be found here: https://github.com/maelysjb/Comics-GenAI/blob/main/README.md#:~:text=.gitignore-,README,-.md

In [None]:
login(token="TOKEN_FROM_HF") 
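
Pasting a token directly into a notebook makes it easy to leak when the notebook is shared. One possible alternative, sketched here as an assumption rather than the setup we used, is to read the token from an environment variable. The helper name `read_hf_token` and the variable name `HF_TOKEN` are illustrative choices.

```python
import os


def read_hf_token(environ=os.environ, name="HF_TOKEN"):
    """Return the token stored in the given environment variable, or raise if unset."""
    token = environ.get(name, "").strip()
    if not token:
        raise RuntimeError(f"Set the {name} environment variable to your Hugging Face token.")
    return token


# Usage in the notebook (requires huggingface_hub to be installed):
# login(token=read_hf_token())
```

With this approach the token never appears in the saved notebook, only in the session's environment.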

### Training DreamBooth Diffusion model
Replace the Hugging Face model id passed to `--hub_model_id` ("DreamBooth2000") with the desired name.


In [None]:
!python /kaggle/working/diffusers/examples/dreambooth/train_dreambooth.py \
    --pretrained_model_name_or_path 'runwayml/stable-diffusion-v1-5' \
    --revision "fp16" \
    --instance_data_dir '/kaggle/working/custom_dataset/instance_images' \
    --class_data_dir '/kaggle/working/custom_dataset/class_images' \
    --instance_prompt 'An image of UnicornGirl in unicorn onesie.' \
    --class_prompt 'An image of UnicornGirl in a unicorn onesie.' \
    --with_prior_preservation \
    --prior_loss_weight 1.0 \
    --num_class_images 50 \
    --output_dir '/kaggle/working/outputs' \
    --resolution 512 \
    --train_text_encoder \
    --train_batch_size 2 \
    --sample_batch_size 2 \
    --max_train_steps 2000 \
    --checkpointing_steps 1850 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing \
    --learning_rate 1e-6 \
    --lr_scheduler 'constant' \
    --lr_warmup_steps=0 \
    --use_8bit_adam \
    --validation_prompt 'An image of UnicornGirl in a unicorn onesie.' \
    --num_validation_images 4 \
    --mixed_precision "fp16" \
    --enable_xformers_memory_efficient_attention \
    --set_grads_to_none \
    --push_to_hub \
    --hub_model_id DreamBooth2000 
    #--report_to 'wandb'
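
The `--with_prior_preservation` and `--prior_loss_weight` flags add a second loss term, computed on the class images, to the loss computed on the instance images. The sketch below illustrates that combination with plain Python lists standing in for tensors. The function names and toy numbers are assumptions for illustration, and this is not the code of `train_dreambooth.py`.

```python
def mse(pred, target):
    """Mean squared error between two equally sized lists of numbers."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def dreambooth_loss(instance_pred, instance_target, class_pred, class_target,
                    prior_loss_weight=1.0):
    """Instance loss plus the weighted prior-preservation loss."""
    instance_loss = mse(instance_pred, instance_target)
    prior_loss = mse(class_pred, class_target)
    return instance_loss + prior_loss_weight * prior_loss


# Toy values: instance loss is 2.0 and prior loss is 1.0.
print(dreambooth_loss([0.0, 2.0], [0.0, 0.0], [1.0, 1.0], [0.0, 0.0]))  # 3.0
```

With a weight of 1.0, as in the command above, both terms count equally. A smaller weight would let the model fit the instance images more closely at the cost of drifting further from what the base model generates for the class prompt.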