# Fine-Tuning a Model

Let's teach our text-to-image generator what my dog looks like.

![redhat dog](https://rhods-public.s3.amazonaws.com/sample-data/images/redhat-dog-small.jpg)

### Check GPU

In [None]:
# Print the standard nvidia-smi status table for the GPUs visible to this kernel.
!nvidia-smi
# Then print one header-less CSV line per GPU: name, total memory, free memory.
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

## Install Requirements

In [None]:
# Quietly install (or upgrade to the latest release of) accelerate, transformers and ftfy.
!pip install -q --upgrade accelerate transformers ftfy
# Install diffusers straight from its GitHub repository; no branch or tag is
# given, so whatever the repository's default branch currently holds is used.
!pip install -q git+https://github.com/huggingface/diffusers

In [None]:
!pip list | grep -e torch -e torchvision -e diffusers -e accelerate -e torchvision -e transformers -e ftfy -e tensorboard -e Jinja2


## Settings

In [None]:
import os
from datetime import datetime

# Every run gets its own "simple-<timestamp>" folder so earlier results are
# never overwritten.
date = datetime.now()
date_string = date.strftime("%Y%m%d%H%M%S")
VERSION = f"simple-{date_string}"

# Settings that may be overridden through environment variables of the same name.
MODEL_NAME = os.getenv("MODEL_NAME", "runwayml/stable-diffusion-v1-5")
INSTANCE_DATA_URL = os.getenv("INSTANCE_DATA_URL", "https://rhods-public.s3.amazonaws.com/sample-data/images/redhat-dog.tar.gz")
INSTANCE_PROMPT = os.getenv("INSTANCE_PROMPT", "photo of a rhteddy dog")
MAX_TRAIN_STEPS = int(os.getenv("MAX_TRAIN_STEPS", "400"))

# Locations derived from the current working directory and the run's VERSION.
OUTPUT_DIR = os.path.join(os.getcwd(), f"{VERSION}/stable_diffusion_weights/redhat-dog")
DATA_DIR = os.path.join(os.getcwd(), f"{VERSION}/data")
INSTANCE_DIR = os.path.join(DATA_DIR, "instance_dir")

# Create the weights folder and the training-image folder up front.
for _path in (OUTPUT_DIR, INSTANCE_DIR):
    os.makedirs(_path, exist_ok=True)

# Report the effective settings, one per line.
print(
    "\n".join(
        (
            f"Weights will be saved at {OUTPUT_DIR}",
            f"MODEL_NAME={MODEL_NAME}",
            f"Training data located in {INSTANCE_DIR}",
            f"Instance Prompt = {INSTANCE_PROMPT}",
        )
    )
)



## Training

### Get Training Data


In [None]:
import sys
import os
import tarfile
import urllib

url = INSTANCE_DATA_URL
output = f"instance-images.tar.gz"
urllib.request.urlretrieve(url, output)

!tar -xzf instance-images.tar.gz --strip-components=1 -C $INSTANCE_DIR

### Set up the Training Job

In [None]:
# Create accelerate's default configuration without going through the
# interactive questionnaire, so the launch command below has a config to use.
!accelerate config default

In [None]:
# Fetch the DreamBooth example training script from the main branch of the
# diffusers repository and save it as train_dreambooth.py in the working directory.
!wget https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth.py -O train_dreambooth.py

In [None]:
!echo "MODEL_NAME=$MODEL_NAME"
!echo "OUTPUT_DIR=$OUTPUT_DIR"
!echo "DATA_DIR=$DATA_DIR"
!echo "INSTANCE_DIR=$INSTANCE_DIR"
!echo "CLASS_DATA_URL=$CLASS_DATA_URL"
!echo "CLASS_DIR=$CLASS_DIR"
!echo "INSTANCE_PROMPT=$INSTANCE_PROMPT"
!echo "CLASS_PROMPT=$CLASS_PROMPT"
!echo "NUM_CLASS_IMAGES=$NUM_CLASS_IMAGES"
!echo "MAX_TRAIN_STEPS=$MAX_TRAIN_STEPS"

In [None]:
!accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="$INSTANCE_PROMPT" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=$MAX_TRAIN_STEPS 

## Test Inference

In [None]:
import torch
from diffusers import DiffusionPipeline

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the pipeline from the fine-tuned weights in OUTPUT_DIR and move it to
# the chosen device.
pipeline = DiffusionPipeline.from_pretrained(OUTPUT_DIR)
pipeline.to(device)

In [None]:
from datetime import datetime
from torch import autocast

# Generation settings.
height = 512
width = 512
num_samples = 6
num_inference_steps = 200
guidance_scale = 7.5
negative_prompt = ""
prompt = f"{INSTANCE_PROMPT} on the beach"

# Give the generator a fresh, non-deterministic seed for every run.
generator = torch.Generator(device=device)
generator.seed()

# Generate all samples in a single pipeline call, under autocast for the
# active device type and with inference mode enabled.
with autocast(device.type):
    with torch.inference_mode():
        images = pipeline(
            prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_samples,
            generator=generator,
        ).images

# Store this batch in its own timestamped folder under the run directory.
date = datetime.now()
date_string = date.strftime('%Y%m%d%H%M%S')
IMG_DIR = os.path.join(os.getcwd(), f"{VERSION}/generated-images/{date_string}")
os.makedirs(IMG_DIR, exist_ok=True)

# Files are named by their position in the batch: 0.jpg, 1.jpg, ...
for index, image in enumerate(images):
    image.save(os.path.join(IMG_DIR, f"{index}.jpg"))


In [None]:
import os
import math
import matplotlib.pyplot as plt
from PIL import Image

# Load every .jpg from the most recent generation folder. os.listdir returns
# names in arbitrary order, so sort them to get a stable display order.
directory = IMG_DIR
images = [
    Image.open(os.path.join(directory, filename))
    for filename in sorted(os.listdir(directory))
    if filename.endswith('.jpg')
]

# Show at most 12 images in a 3-column grid.
num_show = min(len(images), 12)
n_cols = 3
# Keep at least one row: plt.subplots rejects a grid with zero rows.
n_rows = max(1, math.ceil(num_show / n_cols))
scale = 4
# squeeze=False keeps `axs` two-dimensional even when the grid has a single
# row; otherwise a one-row grid comes back as a 1-D array.
fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(n_cols * scale, n_rows * scale),
    gridspec_kw={'hspace': 0, 'wspace': 0},
    squeeze=False,
)
# axs.flat walks the grid row by row, so position x maps to images[x].
for x, ax in enumerate(axs.flat):
    ax.axis('off')  # hide ticks and frames, including on unused cells
    if x < num_show:
        ax.imshow(images[x])
plt.show()
