img-2-img example (#440)
* img-2-img example

* added dockerfile
hirovi authored Mar 20, 2024
1 parent 74ff0f0 commit 3a1c358
Showing 4 changed files with 313 additions and 0 deletions.
31 changes: 31 additions & 0 deletions examples/image-to-image/t2i-adapter-sketch/README.md
# t2i-adapter-sketch

## Description

T2I-Adapter is a lightweight network that provides additional conditioning to Stable Diffusion. Each T2I-Adapter checkpoint takes a different type of conditioning as input (here, a sketch) and is paired with a specific base Stable Diffusion checkpoint.

![Screenshot 2024-03-20 at 14 11 16](https://github.com/mystic-ai/pipeline/assets/30600046/96ea83db-c5ef-4f1b-90e5-c6fb3ce97899)
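
Under the hood, the pipeline pairs the sketch adapter with the SDXL base model via `diffusers`. Below is a minimal sketch of the core calls, mirroring `new_pipeline.py` in this example; the prompt and the `sketch` image are illustrative placeholders:

```py
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter

# Sketch-conditioned adapter, paired with the SDXL base checkpoint
adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-sketch-sdxl-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapter,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# `sketch` is assumed to be a black-and-white PIL image of the drawing
image = pipe(prompt="a house by a lake", image=sketch).images[0]
```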

### Local development

This Mystic pipeline uses a custom Dockerfile. To run or upload this pipeline, first build the container image with Docker:

```sh
docker build -t sketch-2-img:latest -f t2i_adapter.dockerfile .
```

You can then run it locally (assuming you have a GPU):

```sh
docker run -p 14300:14300 --gpus all sketch-2-img:latest
```

If you head to `http://localhost:14300/play`, you will see an auto-generated UI for interacting with the pipeline. Note that this pipeline requires approximately 15 GB of VRAM; an A100-40GB is recommended.

### Upload

Assuming you have authenticated with Mystic and have a valid API token, you can upload the pipeline to your account by running:

```sh
pipeline container push
```
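
`pipeline container push` picks up the `pipeline.yaml` in this directory (pipeline name, accelerator requirements, and the `new_pipeline:my_new_pipeline` graph), so run it from the example's root. If your CLI version expects the container image to exist locally, build it first with the Docker command above.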
211 changes: 211 additions & 0 deletions examples/image-to-image/t2i-adapter-sketch/new_pipeline.py
from io import BytesIO
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torchvision.transforms.functional as torch_transforms
from controlnet_aux.pidi import PidiNetDetector
from controlnet_aux.util import HWC3, resize_image
from diffusers import (
AutoencoderKL,
EulerAncestralDiscreteScheduler,
StableDiffusionXLAdapterPipeline,
T2IAdapter,
)
from PIL import Image

from pipeline import File, Pipeline, entity, pipe
from pipeline.objects.graph import InputField, InputSchema, Variable

style_list = [
{
"name": "(No style)",
"prompt": "{prompt}",
"negative_prompt": "",
},
{
"name": "Cinematic",
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
},
{
"name": "3D Model",
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
},
{
"name": "Anime",
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
},
{
"name": "Digital Art",
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
"negative_prompt": "photo, photorealistic, realism, ugly",
},
{
"name": "Photographic",
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
},
{
"name": "Pixel art",
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
},
{
"name": "Fantasy art",
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
},
{
"name": "Neonpunk",
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
},
{
"name": "Manga",
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
},
]

styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"


class ModelKwargs(InputSchema):
    steps: Optional[int] = InputField(
        default=25,
        title="Steps",
        description="Number of denoising steps used for generation.",
    )
    guidance_scale: Optional[float] = InputField(
        default=5.0,
        title="Guidance scale",
        description="Classifier-free guidance scale; higher values follow the prompt more closely.",
    )
    adapter_conditioning_scale: Optional[float] = InputField(
        default=0.8,
        title="Adapter conditioning scale",
        description="How strongly the sketch conditioning steers the generation.",
    )
negative_prompt: Optional[str] = InputField(
default="",
title="Negative Prompt",
description="Provide what you want the model to avoid generating.",
optional=True,
)
style_name: Optional[str] = InputField(
default=DEFAULT_STYLE_NAME,
title="Style",
description="Select a style to generate in.",
choices=STYLE_NAMES,
)


@entity
class DiffusionWithAdapter:
def __init__(self) -> None:
...

def apply_style(
self, style_name: str, positive: str, negative: str = ""
) -> tuple[str, str]:
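        # e.g. apply_style("Anime", "a cat") returns
        # ("anime artwork a cat . anime style, key visual, ...", "photo, deformed, ...")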
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
return p.replace("{prompt}", positive), n + negative

@pipe(on_startup=True, run_once=True)
def load(self) -> None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

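        # PidiNet converts the input image into a clean sketch/edge map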
self.pidi = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to(device)

# load adapter
adapter = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-sketch-sdxl-1.0",
torch_dtype=torch.float16,
            variant="fp16",
)

# load euler_a scheduler
model_id = "stabilityai/stable-diffusion-xl-base-1.0"

euler_a = EulerAncestralDiscreteScheduler.from_pretrained(
model_id, subfolder="scheduler"
)
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
self.pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
model_id,
vae=vae,
adapter=adapter,
scheduler=euler_a,
torch_dtype=torch.float16,
variant="fp16",
).to(device)
self.pipe.enable_xformers_memory_efficient_attention()

@pipe
def inference(self, image: File, prompt: str, kwargs: ModelKwargs) -> list[File]:
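        # Keep the original upload so it can be returned alongside the generation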
input_image = image
image = Image.open(BytesIO(image.path.read_bytes())).convert("RGB")

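        # Binarise the sketch: resize to 768px, then mark any pixel whose darkest
        # channel is below 80 as a white stroke on a black background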
img_resolution_target = 768
np_img = np.array(image)
img = resize_image(HWC3(np_img), img_resolution_target)
detected_map = np.zeros_like(img, dtype=np.uint8)
detected_map[np.min(img, axis=2) < 80] = 255

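        # Round-trip through a boolean tensor to snap the map to pure black/white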
image = torch_transforms.to_tensor(detected_map) > 0.5
image = torch_transforms.to_pil_image(image.to(torch.float32))

prompt, negative_prompt = self.apply_style(
kwargs.style_name, prompt, kwargs.negative_prompt
)
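        # Refine the binary sketch with PidiNet at 1024px to get the final conditioning map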
image = self.pidi(
image, detect_resolution=1024, image_resolution=1024, apply_filter=True
)

gen_image = self.pipe(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
num_inference_steps=kwargs.steps,
adapter_conditioning_scale=kwargs.adapter_conditioning_scale,
guidance_scale=kwargs.guidance_scale,
).images[0]

print("Image generated")

path = Path("/tmp/dif_w_adapter/image.png")
path.parent.mkdir(parents=True, exist_ok=True)
gen_image.save(str(path))

output_image = File(
path=path, allow_out_of_context_creation=True
) # Return location of generated img
return [input_image, output_image]


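# Wire up the graph: typed inputs -> one-off model load -> inference -> output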
with Pipeline() as builder:
image = Variable(
File,
title="Input sketch",
description="Upload a .png, .jpg or other image file of a sketch",
)
prompt = Variable(str, title="Prompt", description="Prompt to generate from")
kwargs = Variable(ModelKwargs)

model = DiffusionWithAdapter()
model.load()

# Forward pass
out = model.inference(image, prompt, kwargs)

builder.output(out)

my_new_pipeline = builder.get_pipeline()
15 changes: 15 additions & 0 deletions examples/image-to-image/t2i-adapter-sketch/pipeline.yaml
runtime:
container_commands:
- apt update -y
python:
version: '3.10'
requirements:
- pipeline-ai
accelerators: [nvidia_a100_20gb]
accelerator_memory: 15000
pipeline_graph: new_pipeline:my_new_pipeline
pipeline_name: sketch-2-img
description: null
readme: README.md
extras: {}
cluster: null
56 changes: 56 additions & 0 deletions examples/image-to-image/t2i-adapter-sketch/t2i_adapter.dockerfile
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04
ENV DEBIAN_FRONTEND=noninteractive

# python build dependencies
RUN apt update && \
apt install -y bash \
build-essential \
git \
git-lfs \
wget \
curl \
ca-certificates \
libssl-dev \
zlib1g-dev \
libbz2-dev \
libreadline-dev \
libsqlite3-dev \
libncursesw5-dev \
xz-utils \
tk-dev \
libxml2-dev \
libxmlsec1-dev \
libffi-dev \
liblzma-dev \
libgl1 && \
rm -rf /var/lib/apt/lists/*

WORKDIR /app

RUN apt-get update

# Install python
RUN git clone https://github.com/pyenv/pyenv.git /pyenv
ENV PYENV_ROOT=/pyenv
ENV PATH="/pyenv/shims:/pyenv/bin:$PATH"
RUN pyenv install 3.10.10
RUN pyenv global 3.10.10

# Install other python dependencies
RUN pip install setuptools wheel
RUN pip install pipeline-ai accelerate controlnet_aux diffusers
RUN pip install Pillow safetensors timm torch torchvision transformers opencv-python-headless
RUN pip install xformers --index-url https://download.pytorch.org/whl/cu121
RUN pip install fastapi==0.105.0 uvicorn==0.25.0 python-multipart==0.0.6 loguru==0.7.2

# Copy in files
COPY ./ ./

ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV PYTHONUNBUFFERED=1
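# Tell the pipeline-ai container server which graph to serve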
ENV PIPELINE_PATH=new_pipeline:my_new_pipeline
ENV PIPELINE_NAME=sketch-2-img
ENV PIPELINE_IMAGE=sketch-2-img

CMD ["uvicorn", "pipeline.container.startup:create_app", "--host", "0.0.0.0", "--port", "14300"]
