# Sketch to Image Application

Colab 환경에서 스케치 투 이미지 애플리케이션을 만들어봅시다.


## 패키지 및 예제 데이터 다운로드하기
python package들을 설치합니다. Colab에서 실행하지 않는 경우 이 셀은 실행하지 않습니다.

In [None]:
# !wget https://raw.githubusercontent.com/mrsyee/dl_apps/main/image_generation/requirements-colab.txt
# !pip install -r requirements-colab.txt

In [1]:
!pip install diffusers



## 패키지 불러오기

In [2]:
import os
from typing import IO

import gradio as gr
import requests
import torch
from tqdm import tqdm
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm
2024-09-25 08:43:09.516159: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-25 08:43:09.685883: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-25 08:43:09.734701: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-25 08:43:10.075466: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## 스케치 투 이미지 생성 UI 구현하기

In [None]:
WIDTH = 512
HEIGHT = 512

with gr.Blocks() as app:
    gr.Markdown("## 프롬프트 입력")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        n_prompt = gr.Textbox(label="Negative Prompt")

    gr.Markdown("## 스케치 to 이미지 생성")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Canvas"):
                with gr.Row():
                    canvas = gr.Image(
                        label="Draw",
                        source="canvas",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        brush_radius=20,
                        type="pil",
                    )
                with gr.Row():
                    canvas_run_btn = gr.Button(value="Generate")

            with gr.Tab("File"):
                with gr.Row():
                    file = gr.Image(
                        label="Upload",
                        source="upload",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        type="pil",
                    )
                with gr.Row():
                    file_run_btn = gr.Button(value="Generate")

        with gr.Column():
            result_gallery = gr.Gallery(label="Output", height=512)

In [None]:
app.launch(inline=False, share=True)

In [None]:
app.close()

## 모델 다운로드 UI 구현하기

In [None]:
with gr.Blocks() as app:
    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        model_url = gr.Textbox(label="모델 URL", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="모델 다운로드")
    with gr.Row():
        model_file = gr.File(label="모델 파일")

In [None]:
app.launch(inline=False, share=True)

In [None]:
app.close()

## 모델 다운로드 기능 구현하기

In [None]:
import os
import glob

# 전역 변수로 모델 경로와 파일명을 저장
MODEL_PATH = None

# 모델을 다운로드하고 경로를 기억하는 함수
def download_model(url: str) -> str:
    global MODEL_PATH  # 전역 변수를 사용해서 경로를 기억

    model_id = url.replace("https://civitai.com/models/", "").split("/")[0]

    try:
        response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=600)
    except Exception as err:
        print(f"[ERROR] {err}")
        raise err

    download_url = response.json()["modelVersions"][0]["downloadUrl"]
    filename = response.json()["modelVersions"][0]["files"][0]["name"]

    file_path = f"models/{filename}"
    if os.path.exists(file_path):
        print(f"[INFO] File already exists: {file_path}")
        MODEL_PATH = file_path  # 모델 경로 기억
        return file_path

    os.makedirs("models", exist_ok=True)
    download_from_url(download_url, file_path)
    print(f"[INFO] File downloaded: {file_path}")
    
    # 모델 경로 기억
    MODEL_PATH = file_path
    return file_path

# ./models 폴더에서 가장 최근에 수정된 모델 파일 찾기
def find_latest_model_in_directory(directory: str) -> str:
    model_files = glob.glob(f"{directory}/*.safetensors")
    if not model_files:
        return None

     # 가장 최근에 수정된 모델 파일 선택
    latest_model = max(model_files, key=os.path.getmtime)
    return latest_model

In [None]:
with gr.Blocks() as app:
    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        model_url = gr.Textbox(label="모델 URL", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="모델 다운로드")
    with gr.Row():
        model_file = gr.File(label="모델 파일")

In [None]:
app.queue().launch(inline=False, share=True)

In [None]:
app.close()

## 모델 불러오기 UI 및 기능 구현하기

In [None]:
# 다운로드된 모델을 불러오는 함수
def init_pipeline() -> str:
    global MODEL_PATH  # 전역 변수를 사용

    if MODEL_PATH is None:
        # MODEL_PATH가 없으면 ./models에서 모델을 찾음
        print("[INFO] No model path found, searching ./models directory...")
        MODEL_PATH = find_latest_model_in_directory("./models")
    
    if MODEL_PATH is None:
        return "Error: No model found in ./models directory"

    print(f"[INFO] Initialize pipeline with model: {MODEL_PATH}")
    global PIPELINE

    try:
        PIPELINE = StableDiffusionImg2ImgPipeline.from_single_file(
            MODEL_PATH,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")
        print("[INFO] Initialized pipeline")
        return "Model Loaded!"
    except Exception as e:
        print(f"[ERROR] Failed to load model: {e}")
        return f"Error: {e}"

In [None]:
# Gradio 인터페이스 설정
with gr.Blocks() as app:
    gr.Markdown("## 모델 불러오기")
    with gr.Row():
        load_model_btn = gr.Button(value="모델 불러오기")
    with gr.Row():
        is_model_check = gr.Textbox(label="Model Load Check", value="Model Not Loaded")

    download_model_btn.click(
        download_model,
        [model_url],
        [model_file],
    )
    load_model_btn.click(
        init_pipeline,
        None,  # 모델을 불러올 때는 별도의 입력이 필요하지 않음
        [is_model_check],
    )

In [None]:
app.queue().launch(inline=False, share=True)

In [None]:
app.close()

## 스케치 투 이미지 생성 기능 구현하기

In [None]:
def sketch_to_image(sketch: Image.Image, prompt: str, negative_prompt: str):
    width, height = sketch.size
    images =  PIPELINE(
        image=sketch,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_images_per_prompt=4,
        num_inference_steps=20,
        strength=0.7,
    ).images

    with torch.cuda.device("cuda"):
        torch.cuda.empty_cache()

    return images

In [None]:
print("[INFO] Gradio app ready")
with gr.Blocks() as app:
    gr.Markdown("# 스케치 to 이미지 애플리케이션")

    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        model_url = gr.Textbox(label="Model Link", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="Download model")
    with gr.Row():
        model_file = gr.File(label="Model File")

    gr.Markdown("## 모델 불러오기")
    with gr.Row():
        load_model_btn = gr.Button(value="Load model")
    with gr.Row():
        is_model_check = gr.Textbox(label="Model Load Check", value="Model Not loaded")

    gr.Markdown("## 프롬프트 입력")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        n_prompt = gr.Textbox(label="Negative Prompt")

    gr.Markdown("## 스케치 to 이미지 생성")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Canvas"):
                with gr.Row():
                    canvas = gr.Image(
                        label="Draw",
                        source="canvas",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        brush_radius=20,
                        type="pil",
                    )
                with gr.Row():
                    canvas_run_btn = gr.Button(value="Generate")

            with gr.Tab("File"):
                with gr.Row():
                    file = gr.Image(
                        label="Upload",
                        source="upload",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        type="pil",
                    )
                with gr.Row():
                    file_run_btn = gr.Button(value="Generate")

        with gr.Column():
            result_gallery = gr.Gallery(label="Output", height=512)


    # Event
    download_model_btn.click(
        download_model,
        [model_url],
        [model_file],
    )
    load_model_btn.click(
        init_pipeline,
        [model_file],
        [is_model_check],
    )
    canvas_run_btn.click(
        sketch_to_image,
        [canvas, prompt, n_prompt],
        [result_gallery],
    )
    file_run_btn.click(
        sketch_to_image,
        [file, prompt, n_prompt],
        [result_gallery],
    )

In [None]:
app.queue().launch(inline=False, share=True)

In [None]:
app.close()

## 최종 App 구현

In [7]:
import os
from typing import IO
import glob
import gradio as gr
import requests
import torch
import tempfile
import torch_directml
from tqdm import tqdm
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

In [4]:
WIDTH = 512
HEIGHT = 512

# 전역 변수를 사용하여 모델 경로와 파이프라인 객체를 저장
MODEL_PATH = None
PIPELINE = None

In [5]:
# ./models 디렉토리에서 가장 최근에 수정된 모델 파일 찾기
def find_latest_model_in_directory(directory: str) -> str:
    model_files = glob.glob(f"{directory}/*.safetensors")
    if not model_files:
        return None
    
    # 가장 최근에 수정된 모델 파일 선택
    latest_model = max(model_files, key=os.path.getmtime)
    return latest_model

# 모델 다운로드 함수
def download_model(url: str) -> str:
    model_id = url.replace("https://civitai.com/models/", "").split("/")[0]

    try:
        response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=600)
        response.raise_for_status()  # 요청 상태 확인
    except Exception as err:
        print(f"[ERROR] {err}")
        raise err

    # 모델 다운로드 URL 및 파일명 추출
    download_url = response.json()["modelVersions"][0]["downloadUrl"]
    filename = response.json()["modelVersions"][0]["files"][0]["name"]

    file_path = f"models/{filename}"
    
    # 이미 다운로드된 파일이 존재하는 경우
    if os.path.exists(file_path):
        print(f"[INFO] File already exists: {file_path}")
        return file_path

    # 모델 저장 디렉토리 생성
    os.makedirs("models", exist_ok=True)
    
    # 모델 다운로드
    download_from_url(download_url, file_path)
    print(f"[INFO] File downloaded: {file_path}")
    return file_path


# URL로부터 파일 다운로드 함수
def download_from_url(url: str, file_path: str, chunk_size=1024):
    try:
        resp = requests.get(url, stream=True)
        resp.raise_for_status()  # 다운로드 요청 상태 확인
    except Exception as err:
        print(f"[ERROR] {err}")
        raise err

    total = int(resp.headers.get('content-length', 0))  # 파일 크기 추출
    with open(file_path, 'wb') as file, tqdm(
        desc=file_path,
        total=total,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)
            bar.update(size)

# 모델 파이프라인 초기화 함수
def init_pipeline() -> str:
    global MODEL_PATH  # 전역 변수를 사용

    if MODEL_PATH is None:
        # MODEL_PATH가 없으면 ./models에서 모델을 찾음
        print("[INFO] No model path found, searching ./models directory...")
        MODEL_PATH = find_latest_model_in_directory("./models")
    
    if MODEL_PATH is None:
        return "Error: No model found in ./models directory"

    print(f"[INFO] Initialize pipeline with model: {MODEL_PATH}")
    global PIPELINE

    try:
        PIPELINE = StableDiffusionImg2ImgPipeline.from_single_file(
            MODEL_PATH,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")
        print("[INFO] Initialized pipeline")
        return "Model Loaded!"
    except Exception as e:
        print(f"[ERROR] Failed to load model: {e}")
        return f"Error: {e}"

# 스케치에서 이미지를 생성하는 함수
from typing import List

def sketch_to_image(sketch: Image.Image, prompt: List[str], negative_prompt: List[str]):
    global PIPELINE
    if PIPELINE is None:
        return "[ERROR] Pipeline is not initialized."

    # 프롬프트와 네거티브 프롬프트를 리스트로 변환
    if isinstance(prompt, str):
        prompt = [prompt]  # 문자열을 리스트로 변환
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]  # 문자열을 리스트로 변환

    # 프롬프트와 네거티브 프롬프트의 개수를 일치시킴
    if len(prompt) != len(negative_prompt):
        if len(prompt) > len(negative_prompt):
            negative_prompt += [""] * (len(prompt) - len(negative_prompt))
        else:
            prompt += [""] * (len(negative_prompt) - len(prompt))

    width, height = sketch.size

    # 프롬프트 수에 맞게 이미지를 복제
    images = [sketch] * len(prompt)

    print(f"[INFO] Generating image with dimensions: {width}x{height}")

    try:
        # 이미지 생성
        result = PIPELINE(
            image=images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_images_per_prompt=2,
            num_inference_steps=10,
            strength=0.7,
        ).images
    except Exception as e:
        print(f"[ERROR] Failed to generate image: {e}")
        return f"Error: {e}"

    # GPU 메모리 캐시 비우기
    with torch.cuda.device("cuda"):
        torch.cuda.empty_cache()

    return result

In [6]:
print("[INFO] Gradio app ready")

with gr.Blocks() as app:
    gr.Markdown("# 스케치 to 이미지 애플리케이션")

    # 모델 다운로드 섹션
    gr.Markdown("## 모델 다운로드")
    with gr.Row():
        model_url = gr.Textbox(label="Model Link", placeholder="https://civitai.com/")
        download_model_btn = gr.Button(value="Download model")
    with gr.Row():
        download_status = gr.Textbox(label="Download Status", value="Not downloaded yet")

    # 모델 불러오기 섹션
    gr.Markdown("## 모델 불러오기")
    with gr.Row():
        load_model_btn = gr.Button(value="Load model")
    with gr.Row():
        is_model_check = gr.Textbox(label="Model Load Check", value="Model Not loaded")

    # 프롬프트 입력 섹션
    gr.Markdown("## 프롬프트 입력")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        n_prompt = gr.Textbox(label="Negative Prompt")

    # 스케치 to 이미지 생성 섹션
    gr.Markdown("## 스케치 to 이미지 생성")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Canvas"):
                with gr.Row():
                    canvas = gr.Image(
                        label="Draw",
                        source="canvas",
                        image_mode="RGB",
                        tool="color-sketch",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        shape=(WIDTH, HEIGHT),
                        brush_radius=20,
                        type="pil",
                    )
                with gr.Row():
                    canvas_run_btn = gr.Button(value="Generate from Canvas")

            with gr.Tab("File"):
                with gr.Row():
                    file = gr.Image(
                        label="Upload",
                        source="upload",
                        image_mode="RGB",
                        interactive=True,
                        width=WIDTH,
                        height=HEIGHT,
                        type="pil",
                    )
                with gr.Row():
                    file_run_btn = gr.Button(value="Generate from File")

        # 결과 이미지 갤러리
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", height=512)

    # Event 핸들러 설정
    # 모델 다운로드 버튼 클릭 이벤트
    download_model_btn.click(
        download_model,  # 모델 다운로드 함수 호출
        [model_url],  # 입력으로 model_url 사용
        [download_status],  # 다운로드 상태를 download_status 텍스트박스에 표시
    )

    # 모델 불러오기 버튼 클릭 이벤트
    load_model_btn.click(
        init_pipeline,  # 모델 로드 함수 호출
        [],  # 파일 경로는 함수 내에서 관리하므로 입력 없음
        [is_model_check],  # 모델 로드 상태를 is_model_check 텍스트박스에 표시
    )

    # Canvas에서 이미지 생성 버튼 클릭 이벤트
    canvas_run_btn.click(
        sketch_to_image,  # 스케치에서 이미지 생성 함수 호출
        [canvas, prompt, n_prompt],  # 입력으로 canvas, prompt, negative prompt 사용
        [result_gallery],  # 출력 이미지가 result_gallery에 표시
    )

    # File 업로드에서 이미지 생성 버튼 클릭 이벤트
    file_run_btn.click(
        sketch_to_image,  # 업로드된 이미지에서 생성 함수 호출
        [file, prompt, n_prompt],  # 입력으로 업로드된 파일, prompt, negative prompt 사용
        [result_gallery],  # 출력 이미지가 result_gallery에 표시
    )

# Gradio 애플리케이션 실행
app.queue().launch(inline=False, share=True)

[INFO] Gradio app ready
Running on local URL:  http://127.0.0.1:7860
IMPORTANT: You are using gradio version 3.40.0, however version 4.29.0 is available, please upgrade.
--------
Running on public URL: https://f3a0bdb4a146fae697.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




[INFO] No model path found, searching ./models directory...
[INFO] Initialize pipeline with model: ./models/disneyPixarCartoon_v10.safetensors


Fetching 11 files: 100%|██████████████████████████████████████| 11/11 [00:01<00:00,  7.84it/s]
Some weights of the model checkpoint were not used when initializing CLIPTextModel: 
 ['text_model.embeddings.position_ids']
Loading pipeline components...: 100%|███████████████████████████| 6/6 [01:16<00:00, 12.77s/it]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


[ERROR] Failed to load model: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx


In [None]:
app.close()