# Fine-Tuning a Model

Let's teach our text-to-image generator what my dog looks like. In case you've

## Check GPU

In [1]:
!nvidia-smi
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

Thu Aug 17 19:53:04 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   31C    P8    25W / 300W |      0MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
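`nvidia-smi` is enough to eyeball the GPU, but the `--query-gpu` form in the cell above is also easy to parse if you want the notebook to stop early on an undersized card. A minimal sketch, assuming the default CSV output where memory values carry a ` MiB` suffix; `MIN_FREE_MIB` is an illustrative threshold I picked, not a measured requirement for this training run:

```python
import subprocess

# Hypothetical threshold for illustration only; tune it for your own runs.
MIN_FREE_MIB = 16 * 1024


def parse_gpu_query(csv_text):
    """Parse `nvidia-smi --query-gpu=name,memory.total,memory.free
    --format=csv,noheader` output into (name, total_mib, free_mib) tuples."""
    gpus = []
    for line in csv_text.strip().splitlines():
        name, total, free = (field.strip() for field in line.split(","))
        # Memory fields look like "23028 MiB"; keep only the number.
        gpus.append((name, int(total.split()[0]), int(free.split()[0])))
    return gpus


def query_gpus():
    """Run nvidia-smi and return the parsed rows (raises if it is unavailable)."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total,memory.free",
         "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_query(out)


try:
    for name, total, free in query_gpus():
        status = "OK" if free >= MIN_FREE_MIB else "LOW"
        print(f"{name}: {free}/{total} MiB free ({status})")
except (OSError, subprocess.CalledProcessError):
    print("nvidia-smi not available; no NVIDIA GPU detected")
```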

## Install Requirements

In [2]:
!pip install -q --upgrade accelerate transformers ftfy
!pip install -q git+https://github.com/huggingface/diffusers


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
!wget https://rhods-public.s3.amazonaws.com/src/xformers-0.0.21.dev584-cp39-cp39-manylinux2014_x86_64.whl -O xformers-0.0.21.dev584-cp39-cp39-manylinux2014_x86_64.whl
!pip install xformers-0.0.21.dev584-cp39-cp39-manylinux2014_x86_64.whl

--2023-08-17 19:53:16--  https://rhods-public.s3.amazonaws.com/src/xformers-0.0.21.dev584-cp39-cp39-manylinux2014_x86_64.whl
Resolving rhods-public.s3.amazonaws.com (rhods-public.s3.amazonaws.com)... 16.182.67.57, 54.231.140.137, 54.231.234.177, ...
Connecting to rhods-public.s3.amazonaws.com (rhods-public.s3.amazonaws.com)|16.182.67.57|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 127729734 (122M) [binary/octet-stream]
Saving to: ‘xformers-0.0.21.dev584-cp39-cp39-manylinux2014_x86_64.whl’


2023-08-17 19:53:19 (62.1 MB/s) - ‘xformers-0.0.21.dev584-cp39-cp39-manylinux2014_x86_64.whl’ saved [127729734/127729734]

Processing ./xformers-0.0.21.dev584-cp39-cp39-manylinux2014_x86_64.whl
xformers is already installed with the same version as the provided wheel. Use --force-reinstall to force an installation of the wheel.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m


In [4]:
!pip list | grep -e torch -e diffusers -e accelerate -e transformers -e ftfy -e tensorboard -e Jinja2 -e xformers


accelerate                      0.21.0
diffusers                       0.21.0.dev0
ftfy                            6.1.1
Jinja2                          3.1.2
tensorboard                     2.11.2
tensorboard-data-server         0.6.1
tensorboard-plugin-wit          1.8.1
torch                           1.13.1
torchvision                     0.14.1
transformers                    4.31.0
xformers                        0.0.21.dev584

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


## Settings

In [5]:
import os

VERSION = "optimized"
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = os.path.join(os.getcwd(), f"{VERSION}/stable_diffusion_weights/redhat-dog")
DATA_DIR = os.path.join(os.getcwd(), f"{VERSION}/data")
INSTANCE_DIR = os.path.join(DATA_DIR, "redhat-dog")
CLASS_DIR = os.path.join(DATA_DIR, "dog")
INSTANCE_PROMPT = "photo of a rhteddy dog"
CLASS_PROMPT = "a photo of a dog"

os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(INSTANCE_DIR, exist_ok=True)

print(f"Weights will be saved at {OUTPUT_DIR}")
print(f"MODEL_NAME={MODEL_NAME}")
print(f"Training data located in {INSTANCE_DIR}")
print(f"Instance Prompt = {INSTANCE_PROMPT}")


Weights will be saved at /opt/app-root/src/text-to-image-demo/optimized/stable_diffusion_weights/redhat-dog
MODEL_NAME=runwayml/stable-diffusion-v1-5
Training data located in /opt/app-root/src/text-to-image-demo/optimized/data/redhat-dog
Instance Prompt = photo of a rhteddy dog


## Training

### Get Training Data


In [6]:
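import os
import tarfile
import urllib.request  # "import urllib" alone does not guarantee urllib.request is loaded

# Download the sample photos of the dog and unpack them into DATA_DIR
url = "https://rhods-public.s3.amazonaws.com/sample-data/images/redhat-dog.tar.gz"
output = os.path.join(DATA_DIR, "redhat-dog.tar.gz")
urllib.request.urlretrieve(url, output)

# The context manager closes the archive even if extraction fails
with tarfile.open(output) as file:
    file.extractall(DATA_DIR)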
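`extractall()` trusts the member paths stored in the archive, so a crafted tarball could write outside `DATA_DIR` (for example through `../` names or links). For archives you don't control, a more defensive sketch checks every member first and, on Python versions that provide `tarfile.data_filter` (3.12, plus recent security releases of older versions), also passes `filter="data"`. `safe_extract` is a hypothetical helper, not part of the original notebook:

```python
import os
import tarfile


def safe_extract(archive_path, dest_dir):
    """Extract a .tar.gz, refusing links and members that land outside dest_dir."""
    dest = os.path.realpath(dest_dir)
    with tarfile.open(archive_path, "r:gz") as tar:
        members = tar.getmembers()
        for member in members:
            # Links could point outside dest_dir; this sketch simply rejects them.
            if member.issym() or member.islnk():
                raise ValueError(f"Refusing link member: {member.name}")
            target = os.path.realpath(os.path.join(dest, member.name))
            if os.path.commonpath([dest, target]) != dest:
                raise ValueError(f"Blocked path traversal: {member.name}")
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, members=members, filter="data")
        else:
            tar.extractall(dest, members=members)
    return [m.name for m in members]
```

In the notebook it would replace the last two lines of the cell above as `safe_extract(output, DATA_DIR)`.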

### Set up the Training Job

In [7]:
!accelerate config default

Configuration already exists at /opt/app-root/src/.cache/huggingface/accelerate/default_config.yaml, will not override. Run `accelerate config` manually or pass a different `save_location`.


In [8]:
!wget -O train_dreambooth.py https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth.py

--2023-08-17 19:53:25--  https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 55732 (54K) [text/plain]
Saving to: ‘train_dreambooth.py’


2023-08-17 19:53:26 (56.7 MB/s) - ‘train_dreambooth.py’ saved [55732/55732]



### Start Training

In [9]:
!echo $MODEL_NAME 
!echo $INSTANCE_DIR
!echo $CLASS_DIR
!echo $OUTPUT_DIR
!echo $INSTANCE_PROMPT

!accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="$INSTANCE_PROMPT" \
  --class_prompt="$CLASS_PROMPT" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=1000 \
  --enable_xformers_memory_efficient_attention

runwayml/stable-diffusion-v1-5
/opt/app-root/src/text-to-image-demo/optimized/data/redhat-dog
/opt/app-root/src/text-to-image-demo/optimized/data/dog
/opt/app-root/src/text-to-image-demo/optimized/stable_diffusion_weights/redhat-dog
photo of a rhteddy dog
A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'
08/17/2023 19:53:30 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

{'requires_safety_checker'} was not found in config. Values will be initialized to default values.
Loading pipeline components...:   0%|                     | 0/6 [00:00<?, ?it/s]{'projection_class_embeddings_input_dim', 'addition_embed_type', 'num_attention_heads', 'timestep_post_act', 'upcast_attention', 'cross_attention_norm', 'transformer_layers_per_block', 'only_cross_attention', 'mid_block_only_cross_attention', 'class_embeddings_concat', 'time_embed
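How far do 1000 steps go? As far as I can tell from `train_dreambooth.py`, with prior preservation the dataset length is the larger of the instance and class image counts, each step concatenates an instance batch with a class batch, and the number of epochs is `ceil(max_train_steps / updates_per_epoch)`. A small sketch of that arithmetic; the instance count of 5 is a placeholder, since the number of photos in `redhat-dog.tar.gz` isn't shown here:

```python
import math


def dreambooth_schedule(max_train_steps, train_batch_size, grad_accum_steps,
                        num_instance_images, num_class_images, with_prior=True):
    """Estimate dataset size, optimizer updates per epoch, and epochs.

    Assumes the dataset length is max(instance, class) images when prior
    preservation is on, and that one optimizer update consumes
    grad_accum_steps batches.
    """
    dataset_len = (max(num_instance_images, num_class_images)
                   if with_prior else num_instance_images)
    batches_per_epoch = math.ceil(dataset_len / train_batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    epochs = math.ceil(max_train_steps / updates_per_epoch)
    return {"dataset_len": dataset_len,
            "updates_per_epoch": updates_per_epoch,
            "epochs": epochs}


# Settings from the command above; 5 instance photos is a hypothetical value.
print(dreambooth_schedule(1000, 1, 1, num_instance_images=5, num_class_images=200))
# {'dataset_len': 200, 'updates_per_epoch': 200, 'epochs': 5}
```

Under these assumptions, any instance set smaller than `--num_class_images=200` yields 5 passes over the 200 class images. Each instance photo, however, is revisited many times within every epoch.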

## Test Inference

In [10]:
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(OUTPUT_DIR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline.to(device)
pipeline.enable_xformers_memory_efficient_attention()


The history saving thread hit an unexpected error (OperationalError('database or disk is full')).History will not be written to the database.


A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


OSError: Error no file named model_index.json found in directory /opt/app-root/src/text-to-image-demo/optimized/stable_diffusion_weights/redhat-dog.
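`DiffusionPipeline.from_pretrained()` needs `model_index.json`, which, as far as I can tell, `train_dreambooth.py` only writes when it saves the full pipeline at the end of training. Its absence therefore suggests the run above never finished. The earlier warning about `database or disk is full` hints (but does not prove) that the volume ran out of space. A small stdlib sketch to check before loading; the helper names are hypothetical:

```python
import os
import shutil


def has_saved_pipeline(output_dir):
    """True if output_dir looks like a saved diffusers pipeline."""
    return os.path.isfile(os.path.join(output_dir, "model_index.json"))


def describe_dir(path):
    """Short summary of what a directory contains, for debugging a failed run."""
    if not os.path.isdir(path):
        return f"{path} does not exist"
    entries = sorted(os.listdir(path))
    return f"{path} contains {len(entries)} entries: {entries}"


def free_gib(path):
    """Free space, in GiB, on the filesystem that holds path."""
    return shutil.disk_usage(path).free / 2**30
```

In the notebook, `if not has_saved_pipeline(OUTPUT_DIR): print(describe_dir(OUTPUT_DIR), f"{free_gib(os.getcwd()):.1f} GiB free")` would show what training left behind and how much room remains before re-running it.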

In [None]:
import os
from datetime import datetime

import torch
from torch import autocast

# Generation settings
prompt = "photo of a rhteddy dog on the beach"
negative_prompt = ""
num_samples = 6
guidance_scale = 7.5
num_inference_steps = 200
height = 512
width = 512

# Seed a generator on the pipeline's device; print the seed so a good
# batch can be reproduced later with generator.manual_seed(seed).
generator = torch.Generator(device=device)
seed = generator.seed()
print(f"Seed: {seed}")

with autocast(device.type), torch.inference_mode():
    images = pipeline(
        prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_samples,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images

# Save the images into a timestamped directory (date_string must be
# defined before IMG_DIR is built from it)
date_string = datetime.now().strftime('%Y%m%d%H%M%S')
IMG_DIR = os.path.join(os.getcwd(), f"{VERSION}/generated-images/{date_string}")
os.makedirs(IMG_DIR, exist_ok=True)

for i, img in enumerate(images):
    img.save(os.path.join(IMG_DIR, f"{i}.jpg"))
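`guidance_scale` is the weight $s$ in classifier-free guidance. At each of the `num_inference_steps` denoising steps, the diffusers Stable Diffusion pipeline predicts the noise twice, once conditioned on the prompt $c$ and once on the negative prompt $c_{-}$ (the empty string here), and combines the two:

$$
\hat{\epsilon}(x_t) = \epsilon_\theta(x_t, c_{-}) + s\,\bigl(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, c_{-})\bigr), \qquad s = 7.5
$$

With $s = 1$ this reduces to the plain conditional prediction; larger $s$ pushes samples harder toward the prompt at some cost in diversity. Because both predictions are computed in one batch, each step runs the UNet on $2 \times 6 = 12$ latents, so 200 steps is noticeably heavier than the pipeline's default of 50.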


In [None]:
import os
import math
import matplotlib.pyplot as plt
from PIL import Image

directory = IMG_DIR
# Load the generated images in a stable (alphabetical) order;
# os.listdir() returns files in arbitrary order.
filenames = sorted(f for f in os.listdir(directory) if f.endswith('.jpg'))
images = [Image.open(os.path.join(directory, f)) for f in filenames]

num_show = min(len(images), 12)  # show at most 12 images
n_cols = 3
n_rows = max(1, math.ceil(num_show / n_cols))
scale = 4
# squeeze=False keeps axs two-dimensional, so axs[i, j] also works for a single row
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * scale, n_rows * scale),
                        gridspec_kw={'hspace': 0, 'wspace': 0}, squeeze=False)
for i in range(n_rows):
    for j in range(n_cols):
        axs[i, j].axis('off')
        x = i * n_cols + j
        if x < num_show:
            axs[i, j].imshow(images[x])
plt.show()
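To save the whole batch as one contact-sheet image (for sharing outside the notebook), the same layout can be expressed as pixel offsets for Pillow's `Image.paste`. A minimal sketch; `grid_boxes` is a hypothetical helper, and the 3-column and 12-image limits simply mirror the cell above:

```python
import math


def grid_boxes(n_images, tile_w, tile_h, n_cols=3, max_show=12):
    """Return the (left, top) offset of each tile in row-major order,
    plus the (width, height) of the full canvas."""
    n_show = min(n_images, max_show)
    n_rows = max(1, math.ceil(n_show / n_cols))
    boxes = [((k % n_cols) * tile_w, (k // n_cols) * tile_h) for k in range(n_show)]
    return boxes, (n_cols * tile_w, n_rows * tile_h)


boxes, size = grid_boxes(6, 512, 512)
print(size)   # (1536, 1024)
print(boxes)
```

With Pillow, `canvas = Image.new("RGB", size)` followed by `canvas.paste(img, box)` for each image and box would build the sheet, assuming every image has the 512x512 size set earlier.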
