Paper (論文)<br>
https://arxiv.org/abs/2307.10159<br>
<br>
GitHub<br>
https://github.com/sd-fabric/fabric<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/Fabric_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up environment

## git clone

In [None]:
# Clone the FABRIC repository into /content (IPython magics, not plain Python).
%cd /content

!git clone https://github.com/sd-fabric/fabric.git

%cd /content/fabric
# Pin the checkout to a fixed commit (commits on Jul 22, 2023), presumably so
# the notebook keeps matching the API it was written against.
!git checkout 46787ad03716e310c7680174cdad9b0efbc393b5

## install libraries

In [None]:
%cd /content/fabric

# Install the repository's pinned dependencies first.
!pip install -r requirements.txt
# Install the cloned package in editable mode (-e) so `import fabric`
# resolves to this checkout.
!pip install -e .

## import libraries

In [None]:
%cd /content/fabric

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import clear_output
import functools

import torch

from fabric.generator import AttentionBasedGenerator
from fabric.iterative import IterativeFeedbackGenerator

# Load model

In [None]:
model_name = "runwayml/stable-diffusion-v1-5"

# Check for a CUDA device once; the answer selects both the weight
# precision (float16 on GPU, float32 on CPU) and the device placement.
use_cuda = torch.cuda.is_available()

# Download (if needed) and load the Stable Diffusion weights.
base_generator = AttentionBasedGenerator(
  model_name=model_name,
  model_ckpt=None,
  torch_dtype=torch.float16 if use_cuda else torch.float32,
)

# Move the model onto the GPU when one is present.
if use_cuda:
  base_generator.to("cuda")

# Wrap the base model in the iterative like/dislike feedback generator.
generator = IterativeFeedbackGenerator(base_generator)

# Set up config

## text prompt

In [None]:
# Text prompt and negative prompt; next_round() passes both to
# generator.generate() on every round. The `#@param` annotations turn these
# lines into Colab form fields, so they must stay on the assignment lines.
# NOTE(review): "lower, low" looks like it may have been meant as a quality
# negative prompt such as "lowres, low quality" — confirm intent.
prompt = "japan beautiful scenarios" #@param {type: "string"}
negative_prompt = "lower, low" #@param {type: "string"}

## other params

In [None]:
# Generation settings. next_round() reads these module-level values each time
# it runs, so editing and re-running this cell affects the next round.
# `denoising_steps`: Number of steps in the denoising schedule
# `guidance_scale`: Strength of the classifier-free guidance (same as for any diffusion model)
# `feedback_start`: From which point in the diffusion process feedback should be added (0.0 -> from the beginning, 0.5 -> from the halfway point)
# `feedback_end`: Until which point feedback should be added (0.5 -> until the halfway point, 1.0 -> until the end)
# `seed`: Random seed passed to every generate() call. Reusing the same seed each
#   round presumably keeps rounds comparable, so differences come from the feedback.

denoising_steps = 20
guidance_scale = 6.0
feedback_start = 0.0
feedback_end = 0.5
seed = 12

# Define functions

In [None]:
# for showing created images
def display_images(images, n_cols=4, size=4):
  """Lay out ``images`` on a grid with ``n_cols`` columns in one figure.

  Each grid cell is ``size`` x ``size`` inches. Subplots are titled
  "Image 1", "Image 2", ... in input order and have their axes hidden.
  The figure is only built here; the caller is responsible for showing it
  (e.g. with ``plt.show()``).

  Args:
    images: Sized sequence of images accepted by ``Axes.imshow``.
    n_cols: Number of grid columns.
    size: Edge length of one grid cell, in inches.

  Returns:
    The populated matplotlib figure.
  """
  count = len(images)
  # Ceiling division (ceil(a / b) == -(-a // b)): enough rows for every image.
  rows = int(-(-count // n_cols))
  figure = plt.figure(figsize=(size * n_cols, size * rows))
  for position, picture in enumerate(images, start=1):
    axis = figure.add_subplot(rows, n_cols, position)
    axis.imshow(picture)
    axis.set_title(f"Image {position}")
    axis.axis("off")
  figure.tight_layout()
  return figure

In [None]:
# for feedback: click handlers bound per image by next_round()
def clicked_like(img, i, _):
  """Register ``img`` as liked and print a confirmation label.

  Args:
    img: The image the 👍 button belongs to.
    i: Zero-based position of the image in the current round.
    _: The clicked button (ignored).
  """
  generator.give_feedback(liked=[img])
  display(widgets.Label(value=f"Added image {i+1} to liked images"))

def clicked_dislike(img, i, _):
  """Register ``img`` as disliked and print a confirmation label.

  Args:
    img: The image the 👎 button belongs to.
    i: Zero-based position of the image in the current round.
    _: The clicked button (ignored).
  """
  generator.give_feedback(disliked=[img])
  display(widgets.Label(value=f"Added image {i+1} to disliked images"))

# One 👍 / 👎 button pair per displayed image (four slots). next_round()
# attaches the click handlers once it knows which image each slot holds.
like_buttons, dislike_buttons = [], []
for number in range(1, 5):
  like_buttons.append(widgets.Button(
      description=f"👍 Image {number}", button_style="success", tooltip="Add to liked images"))
  dislike_buttons.append(widgets.Button(
      description=f"👎 Image {number}", button_style="danger", tooltip="Add to disliked images"))

# Lay each row of buttons out horizontally.
like_container = widgets.HBox(like_buttons)
dislike_container = widgets.HBox(dislike_buttons)

In [None]:
# for inference
def next_round(_):
  """Generate a new batch of images and show it with feedback buttons.

  Serves as the "Next Round" button callback and is also called directly
  with ``None`` for the first round. The module-level settings
  (``prompt``, ``negative_prompt``, ``denoising_steps``, ``guidance_scale``,
  ``feedback_start``, ``feedback_end``, ``seed``) are read at call time, so
  edits to those cells apply from the next round on.

  Args:
    _: The clicked button (ignored).
  """
  clear_output()
  images = generator.generate(
      prompt=prompt,
      negative_prompt=negative_prompt,
      denoising_steps=denoising_steps,
      guidance_scale=guidance_scale,
      feedback_start=feedback_start,
      feedback_end=feedback_end,
      seed=seed
  )
  # Clear again to drop any output emitted while generating.
  clear_output()

  display_images(images)
  plt.show()

  # Detach last round's handlers from every button first, so a button with no
  # image in this round can never send a stale image as feedback.
  # NOTE: `_click_handlers` is a private ipywidgets attribute (the dispatcher
  # behind `on_click`); it may change between ipywidgets versions.
  for button in like_buttons + dislike_buttons:
    button._click_handlers.callbacks = []

  # Bind each button pair to its image. zip stops at the shorter sequence, so
  # a batch whose size differs from the number of buttons no longer raises
  # IndexError (the old code indexed images[0..3] unconditionally).
  for i, (img, like_button, dislike_button) in enumerate(
      zip(images, like_buttons, dislike_buttons)):
    like_button.on_click(functools.partial(clicked_like, img, i))
    dislike_button.on_click(functools.partial(clicked_dislike, img, i))

  display(like_container)
  display(dislike_container)
  display(control_buttons)

def reset(_):
  """Clear all collected feedback images ("Reset Feedback" button callback).

  Args:
    _: The clicked button (ignored).
  """
  generator.reset()
  display(widgets.Label(value="All feedback images have been cleared."))

# Inference

In [None]:
# Round controls: "Next Round" generates a new batch and "Reset Feedback"
# clears the collected likes/dislikes.
next_round_button = widgets.Button(description="Next Round", button_style="info")
reset_button = widgets.Button(description="Reset Feedback", tooltip="Clear all feedback images")
next_round_button.on_click(next_round)
reset_button.on_click(reset)
control_buttons = widgets.HBox([next_round_button, reset_button])

# Start from an empty feedback set and render the first round.
generator.reset()
next_round(None)