<a href="https://colab.research.google.com/github/calm-ixia/SDManualGUI/blob/main/PNGInfo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
#@title ライブラリ導入

#version 7.7.1 < latest 8.0.4
from ipywidgets import widgets, Layout
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import io

import os
import math
import time

In [None]:
#@title ユーティリティ

# 共通の処理

def create_batch_image(outImages, width, height, grid_column_count):
  """Paste one batch of images into a single grid image and return it.

  Images are placed row-major (left to right, then top to bottom) in cells of
  ``width`` x ``height`` pixels, ``grid_column_count`` cells per row. The
  canvas always spans the full column count and as many rows as needed, so
  unused cells in the last row keep the black fill of ``Image.new``.

  Args:
    outImages: sequence of PIL images produced by one batch.
    width, height: size in pixels of one grid cell.
    grid_column_count: number of columns in the grid (must be >= 1).

  Returns:
    A new RGB ``PIL.Image`` containing the grid.

  Raises:
    ValueError: if ``grid_column_count`` is less than 1.
  """
  if grid_column_count < 1:
    raise ValueError(f"grid_column_count must be >= 1, got {grid_column_count}")

  image_count = len(outImages)
  row_count = math.ceil(image_count / grid_column_count)
  grid_image = Image.new('RGB', size=(width * grid_column_count, height * row_count))

  for batch_index, image in enumerate(outImages):
    # divmod yields (row, column) for the row-major placement
    row_index, col_index = divmod(batch_index, grid_column_count)
    grid_image.paste(image, box=(width * col_index, height * row_index))

  return grid_image

def save_image(image, fp, parameter_text):
  """Write ``image`` to ``fp`` as a PNG with ``parameter_text`` embedded.

  The text is attached as a PNG text entry under the key "parameters".
  ``fp`` is handed directly to ``image.save``; in this notebook it is either
  a file path or an ``io.BytesIO`` buffer.
  """
  png_metadata = PngInfo()
  png_metadata.add_text("parameters", parameter_text)
  image.save(fp, pnginfo=png_metadata, format="png")

def create_thumbnail(image, resize=0.125):
  """Return a copy of ``image`` scaled by the factor ``resize``.

  Each scaled dimension is truncated to an integer and clamped to at least
  one pixel. Pillow refuses to resize to a zero width or height, which very
  small images or tiny factors would otherwise produce.
  """
  new_width = max(1, int(image.width * resize))
  new_height = max(1, int(image.height * resize))
  return image.resize(size=(new_width, new_height))

"""
widgets.FileUpload から得られる値は、
{filename: {"metadata":{}, "content":rawdata} } の構造を持つdict
"""
def update_pnginfo(new_value):
  """Return the raw bytes and embedded "parameters" text of an uploaded image.

  ``new_value`` is a ``widgets.FileUpload`` value, a dict shaped like
  ``{filename: {"metadata": {...}, "content": rawdata}}``. Only the first
  uploaded file is examined.

  Returns:
    ``(data, parameter_text)``: the file content and the image's
    "parameters" info entry. Returns ``None`` when ``new_value`` is empty.

  Raises:
    KeyError: if the image carries no "parameters" entry.
  """
  for value in new_value.values():
    data = value["content"]
    # Open only to read the metadata; the context manager releases the image
    # as soon as the info dict has been read.
    with Image.open(io.BytesIO(data)) as img:
      parameter_text = img.info["parameters"]
    return (data, parameter_text)
  # Empty upload value (nothing uploaded / widget reset)
  return None

def fileUploader_load_png(new_value):
  """Return the raw content of the first file in a ``widgets.FileUpload`` value.

  ``new_value`` is a dict shaped like
  ``{filename: {"metadata": {...}, "content": rawdata}}``; only "content"
  is read. Returns ``None`` when ``new_value`` is empty.
  """
  return next((value["content"] for value in new_value.values()), None)


In [None]:
#@title 画像生成モジュール用のデータコンテナ

import copy

try:
  from distutils.util import strtobool
except ImportError:
  # distutils was deprecated by PEP 632 and removed in Python 3.12; provide an
  # equivalent so this cell keeps working on newer runtimes.
  def strtobool(val):
    """Convert a truth-value string to 1 or 0, like ``distutils.util.strtobool``.

    True values are 'y', 'yes', 't', 'true', 'on', '1'; false values are
    'n', 'no', 'f', 'false', 'off', '0' (case-insensitive).

    Raises:
      ValueError: if ``val`` is not one of those strings.
    """
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
      return 1
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
      return 0
    raise ValueError(f"invalid truth value {val!r}")

class GeneratedFileInfo:
  """Generation settings (plus optional encoded image bytes) for one image.

  The settings round-trip through the text stored under the "parameters"
  PNG key: ``create_parameter_text`` produces it, ``parse_parameter_text`` /
  ``from_parameter_text`` read it back.
  """
  def __init__(self, prompt="", negative_prompt="",
    scheduler_name="", step_count=0, guidance_scale=0, seed=0, width=0, height=0,
    batch_size=0, batch_count=0, grid_column_count=0, enable_safety_checker=True, eta=0.0, resize=0.25, latents_type=None, extra_param=None
  ):
    """Store the given settings.

    ``extra_param`` holds optional mode-specific settings (keys used in this
    class: 'strength', 'mask_timing', 'mask_output', and the binary entries
    'image', 'mask_image', 'latents'). When omitted, a new empty dict is
    created for this instance, so instances never share and mutate one
    default dict.
    """
    self.path = ""
    # Encoded PNG bytes, filled by save_thumbnail / save_image or by callers
    self.thumbnailBytes: "bytes | None" = None
    self.imageBytes: "bytes | None" = None
    # set parameters
    self.prompt = prompt
    self.negative_prompt = negative_prompt
    self.scheduler_name = scheduler_name
    self.step_count = step_count
    self.guidance_scale = guidance_scale
    self.seed = seed
    self.width = width
    self.height = height
    self.batch_size = batch_size
    self.batch_count = batch_count
    self.grid_column_count = grid_column_count
    self.enable_safety_checker = enable_safety_checker
    self.eta = eta
    self.resize = resize
    self.latents_type = latents_type
    self.extra_param = {} if extra_param is None else extra_param

  def clone(self):
    """Return a copy whose ``extra_param`` (binary entries included) is deep-copied.

    ``imageBytes`` and ``thumbnailBytes`` are not copied (see ``clone_light``).
    """
    cloned = self.clone_light()
    cloned.extra_param = copy.deepcopy(self.extra_param)
    return cloned

  def clone_light(self):
    """Return a copy of the settings without binary data.

    Not copied: imageBytes, thumbnailBytes, and extra_param['image'],
    extra_param['mask_image'], extra_param['latents']. Copy those
    separately when needed.
    """
    cloned = GeneratedFileInfo()
    cloned.path = copy.deepcopy(self.path)
    # set parameters
    cloned.prompt = copy.deepcopy(self.prompt)
    cloned.negative_prompt = copy.deepcopy(self.negative_prompt)
    cloned.scheduler_name = copy.deepcopy(self.scheduler_name)
    cloned.step_count = copy.deepcopy(self.step_count)
    cloned.guidance_scale = copy.deepcopy(self.guidance_scale)
    cloned.seed = copy.deepcopy(self.seed)
    cloned.width = copy.deepcopy(self.width)
    cloned.height = copy.deepcopy(self.height)
    cloned.batch_size = copy.deepcopy(self.batch_size)
    cloned.batch_count = copy.deepcopy(self.batch_count)
    cloned.grid_column_count = copy.deepcopy(self.grid_column_count)
    cloned.enable_safety_checker = copy.deepcopy(self.enable_safety_checker)
    cloned.eta = copy.deepcopy(self.eta)
    cloned.resize = copy.deepcopy(self.resize)
    cloned.latents_type = copy.deepcopy(self.latents_type)
    # Copy only the non-binary extra_param entries
    for key, val in self.extra_param.items():
      if key in ["image", "mask_image", "latents"]:
        continue
      cloned.extra_param[key] = copy.deepcopy(val)

    return cloned

  def create_parameter_text(self):
    """Build the parameter text embedded into saved PNGs.

    Three lines: the prompt, "Negative prompt: <negative prompt>", then
    comma-separated "Key: value" settings, e.g.
    "Steps: 20, Sampler: Euler a, CFG scale: 7.5, Seed: 1, Size: 512x512".
    Denoising strength / Masked content / Mask timing / Mask output are
    appended only when set.
    """
    param = ", ".join([
      f"Steps: {self.step_count}",
      f"Sampler: {self.scheduler_name}",
      f"CFG scale: {self.guidance_scale}",
      f"Seed: {self.seed}",
      f"Size: {self.width}x{self.height}",
    ])
    if 'strength' in self.extra_param:
      param += f", Denoising strength: {self.extra_param['strength']}"
    if self.latents_type is not None:
      param += f", Masked content: {self.latents_type}"
    if 'mask_timing' in self.extra_param:
      param += f", Mask timing: {self.extra_param['mask_timing']}"
    if 'mask_output' in self.extra_param:
      param += f", Mask output: {self.extra_param['mask_output']}"

    return "\n".join([
      self.prompt,
      f"Negative prompt: {self.negative_prompt}",
      param
    ])

  @staticmethod
  def _strip_line_break(text):
    """Remove the one line break that separates sections of the parameter text."""
    if text.endswith("\r\n"):
      return text[:-2]
    if text.endswith("\n"):
      return text[:-1]
    return text

  @classmethod
  def parse_parameter_text(cls, parameter_text):
    """Parse parameter text into a dict of typed values.

    Recognised keys: 'Prompt', 'Negative prompt', 'Steps', 'Sampler',
    'CFG scale', 'Seed', 'Size' (a ``(width, height)`` tuple),
    'Denoising strength', 'Masked content', 'Mask timing', 'Mask output'.
    Other "key: value" pairs are dropped. Malformed pairs are reported with
    a printed warning and skipped.

    The "Negative prompt:" line is optional. Without it, the prompt runs
    up to "Steps:" and 'Negative prompt' is ''.

    Raises:
      ValueError: if a recognised numeric/boolean field cannot be converted.
    """
    negative_marker = "Negative prompt: "
    negative_prompt_begin = parameter_text.find(negative_marker)
    # Look for "Steps:" only after the negative prompt marker (if any), so a
    # "Steps:" inside the prompt is not taken as the start of the settings.
    step_begin = parameter_text.find("Steps:", max(negative_prompt_begin, 0))
    if step_begin == -1:
      # No settings section at all
      step_begin = len(parameter_text)

    if negative_prompt_begin == -1:
      prompt = parameter_text[:step_begin]
      negative_prompt = ""
    else:
      prompt = parameter_text[:negative_prompt_begin]
      negative_prompt = parameter_text[negative_prompt_begin + len(negative_marker):step_begin]

    # Drop the separating line break so the text round-trips through
    # create_parameter_text without gaining blank lines.
    param_dict = {
      'Prompt': cls._strip_line_break(prompt),
      'Negative prompt': cls._strip_line_break(negative_prompt),
    }

    # Settings follow as "foo: **, bar: **"
    for pair in parameter_text[step_begin:].split(','):
      if not pair.strip():
        continue  # empty settings section or trailing comma
      kv = pair.split(':')
      if len(kv) == 2:
        key, value = kv
        # strip() rather than replace(): values such as 'Euler a' contain inner spaces
        param_dict[key.strip()] = value.strip()
      else:
        print(f"Warning: skip key-value {kv}")

    # Convert numeric/boolean strings where needed
    conv = {
      'Prompt': lambda x: x,
      'Negative prompt': lambda x: x,
      'Steps': int,
      'Sampler': lambda x: x,  # kept as the raw string
      'CFG scale': float,
      'Seed': int,
      # "Size: 512x512" -> (512, 512)
      'Size': lambda x: tuple(int(s) for s in x.split('x')),
      'Denoising strength': float,
      'Masked content': lambda x: x,
      'Mask timing': lambda x: x,
      'Mask output': lambda x: bool(strtobool(x)),
    }

    return {key: conv[key](val) for key, val in param_dict.items() if key in conv}

  @classmethod
  def from_parameter_text(cls, parameter_text):
    """Build an instance from parameter text (see ``parse_parameter_text``).

    batch_size, batch_count, grid_column_count, enable_safety_checker and eta
    are not stored in the text and keep their defaults.

    Raises:
      KeyError: if 'Steps', 'Sampler', 'CFG scale', 'Seed' or 'Size' is missing.
    """
    param_dict = cls.parse_parameter_text(parameter_text)
    latents_type = None
    extra_param = {}
    if 'Denoising strength' in param_dict:
      extra_param['strength'] = param_dict['Denoising strength']
    if 'Masked content' in param_dict:
      latents_type = param_dict['Masked content']
    if 'Mask timing' in param_dict:
      extra_param['mask_timing'] = param_dict['Mask timing']
    if 'Mask output' in param_dict:
      extra_param['mask_output'] = param_dict['Mask output']

    return cls(
      prompt = param_dict['Prompt'],
      negative_prompt = param_dict['Negative prompt'],
      scheduler_name = param_dict['Sampler'],
      step_count = param_dict['Steps'],
      guidance_scale = param_dict['CFG scale'],
      seed = param_dict['Seed'],
      width = param_dict['Size'][0],
      height = param_dict['Size'][1],
      latents_type = latents_type,
      extra_param = extra_param,
    )

  def save_image(self, image, filepath):
    """Save ``image`` with this info's parameter text to ``filepath``.

    The same PNG is also encoded in memory and kept in ``self.imageBytes``.
    """
    parameter_text = self.create_parameter_text()
    save_image(image, filepath, parameter_text)
    with io.BytesIO() as buf:
      save_image(image, buf, parameter_text)
      self.imageBytes = buf.getvalue()

  def save_thumbnail(self, image, resize):
    """Encode a ``resize``-scaled copy of ``image`` into ``self.thumbnailBytes``.

    The thumbnail PNG carries this info's parameter text.
    """
    parameter_text = self.create_parameter_text()
    thumbnail = create_thumbnail(image, resize=resize)
    with io.BytesIO() as buf:
      save_image(thumbnail, buf, parameter_text)
      self.thumbnailBytes = buf.getvalue()



In [None]:
#@title PNGInfoペイン

# Widgets are created in __init__; arranging them happens in set_layout().
class PnginfoPaneView:
  """PNG Info tab: upload an image and show/edit its embedded parameters.

  ``txt2imgPane``, ``img2imgPane`` and ``inpaintPane`` are the panes the
  "send to" handlers forward parsed settings to through ``reflectPNGInfo``.
  Any of them may be None; that handler then only parses the text.
  """
  def __init__(self, txt2imgPane, img2imgPane, inpaintPane):
    self.txt2imgPane = txt2imgPane
    self.img2imgPane = img2imgPane
    self.inpaintPane = inpaintPane
    self.param_dict = {}
    self.info = None  # GeneratedFileInfo of the current image, once parsed

    self.pnginfo_prompt_textarea = widgets.Textarea(placeholder="prompt", disabled=True, layout = Layout(width="100%",height="15em"))

    # The HTML "accept" attribute is a comma-separated list of file types
    self.pnginfo_upload_button = widgets.FileUpload(accept=".png,.jpg", multiple=False, description="upload an image")
    # FileUpload has no on_click, so observe its value to react to uploads
    self.pnginfo_upload_button.observe(self.on_pnginfo_upload, names='value')

    self.pnginfo_upload_image = widgets.Image(width="512", height="512")

    self.send_to_txt2img_button = widgets.Button(description="send to txt2img",layout = Layout(width="auto") )
    self.send_to_txt2img_button.on_click(self.on_send_to_txt2img)

    self.send_to_img2img_button = widgets.Button(description="send to img2img",layout = Layout(width="auto") )
    self.send_to_img2img_button.on_click(self.on_send_to_img2img)

    self.send_to_inpaint_button = widgets.Button(description="send to inpaint",layout = Layout(width="auto") )
    self.send_to_inpaint_button.on_click(self.on_send_to_inpaint)

    self.thumbnail_image = widgets.Image(width="64", height="64")

    self.enable_edit_check = widgets.Checkbox(value=False, description="enable edit")
    self.enable_edit_check.observe(self.on_enable_edit_check, names="value")

    self.edit_save_button = widgets.Button(description="save image",layout = Layout(width="auto") )
    self.edit_save_button.on_click(self.on_edit_save)
    self.edit_save_thumbnail_button = widgets.Button(description="save thumbnail",layout = Layout(width="auto") )
    self.edit_save_thumbnail_button.on_click(self.on_edit_save_thumbnail)

  def set_layout(self):
    """Arrange the widgets into the PNG Info pane, store it in ``self.layout`` and return it."""
    # Layout values are CSS, so fixed sizes need an explicit unit ("540px")
    self.layout = \
    widgets.HBox(layout=Layout(width="100%", height="100%"), children=[
      # Left column: upload button and image preview
      widgets.VBox(layout=Layout(width="540px", height="100%"), children=[
        self.pnginfo_upload_button,
        widgets.HBox([self.pnginfo_upload_image], layout=Layout(width="520px", height="520px")),
      ] ),
      # Right column: parameter text, edit/save controls, thumbnail
      widgets.VBox(layout=Layout(width="60%", height="100%"), children=[
        widgets.HBox([self.pnginfo_prompt_textarea], layout=Layout(width="100%", height="100%")),
        widgets.HBox( [
          # "send to" buttons are not shown in this standalone notebook
          #self.send_to_txt2img_button,
          #self.send_to_img2img_button,
          #self.send_to_inpaint_button,
          widgets.VBox([
            self.enable_edit_check,
            widgets.HBox( [self.edit_save_button, self.edit_save_thumbnail_button,]),
          ]),
        ]),
        widgets.HBox([
          self.thumbnail_image,
        ]),
      ]),
    ])
    return self.layout

  def on_pnginfo_upload(self, change):
    """Show a newly uploaded image, its parameter text and a thumbnail.

    Called whenever the FileUpload value changes. The edit checkbox is reset
    so the parameter text starts read-only.
    """
    if not change.new:
      # The upload value was cleared; there is nothing to display.
      return
    (data, parameter_text) = update_pnginfo(new_value=change.new)
    with Image.open(io.BytesIO(data)) as image:
      self.info = GeneratedFileInfo.from_parameter_text(parameter_text)
      self.info.imageBytes = data
      self.pnginfo_upload_image.value = data
      self.pnginfo_upload_image.width = min(512, image.width)
      self.pnginfo_upload_image.height = min(512, image.height)
      self.pnginfo_prompt_textarea.value = parameter_text

      # 1/4-scale thumbnail that keeps parameter_text unchanged
      thumbnail = create_thumbnail(image=image, resize=0.25)

    with io.BytesIO() as buf:
      save_image(thumbnail, buf, parameter_text)
      # Size the widget from the thumbnail itself (integer pixels)
      self.thumbnail_image.width = min(512, thumbnail.width)
      self.thumbnail_image.height = min(512, thumbnail.height)
      self.thumbnail_image.value = buf.getvalue()

    # Editing is always disabled right after an upload
    self.enable_edit_check.value = False

  def _send_to(self, pane):
    """Parse the textarea into ``self.info`` and pass it to ``pane.reflectPNGInfo``.

    The info carries the currently displayed image bytes. Nothing is
    forwarded when ``pane`` is None (no target pane was connected).
    """
    # TODO: should be event-driven
    new_info = GeneratedFileInfo.from_parameter_text(self.pnginfo_prompt_textarea.value)
    new_info.imageBytes = self.pnginfo_upload_image.value
    self.info = new_info
    if pane is not None:
      pane.reflectPNGInfo(self.info)

  def on_send_to_txt2img(self, remove):
    """Button handler: forward the current parameters to the txt2img pane."""
    self._send_to(self.txt2imgPane)

  def on_send_to_img2img(self, remove):
    """Button handler: forward the current parameters to the img2img pane."""
    self._send_to(self.img2imgPane)

  def on_send_to_inpaint(self, remove):
    """Button handler: forward the current parameters to the inpaint pane."""
    self._send_to(self.inpaintPane)

  def on_enable_edit_check(self, change):
    """Make the parameter textarea editable while the checkbox is checked."""
    self.pnginfo_prompt_textarea.disabled = not change.new

  def on_edit_save(self, remove):
    """Re-embed the textarea text into the displayed image.

    Only the image widget's value is replaced; nothing is written to disk.
    """
    if not self.pnginfo_upload_image.value:
      return  # no image uploaded yet
    with Image.open(io.BytesIO(self.pnginfo_upload_image.value)) as img, io.BytesIO() as buf:
      save_image(img, buf, self.pnginfo_prompt_textarea.value)
      self.pnginfo_upload_image.value = buf.getvalue()

  def on_edit_save_thumbnail(self, remove):
    """Re-embed the textarea text into the displayed thumbnail (widget value only)."""
    if not self.thumbnail_image.value:
      return  # no thumbnail generated yet
    with Image.open(io.BytesIO(self.thumbnail_image.value)) as img, io.BytesIO() as buf:
      save_image(img, buf, self.pnginfo_prompt_textarea.value)
      self.thumbnail_image.value = buf.getvalue()

# Create the PNG Info pane on its own: no txt2img / img2img / inpaint panes
# exist in this notebook, so every send-to target is None.
pnginfoPane = PnginfoPaneView(txt2imgPane=None, img2imgPane=None, inpaintPane=None)


In [None]:
#@title 全体レイアウト

# Top-level tabs: one tab per pane, keyed by the tab title.
top_tab_dict = {"PNG Info": pnginfoPane.set_layout()}

top_tab = widgets.Tab(layout=Layout(width="100%", height="100%"))
top_tab.children = list(top_tab_dict.values())
for tab_index, tab_title in enumerate(top_tab_dict):
  top_tab.set_title(tab_index, tab_title)

# Root container holding the tab widget.
main_window = widgets.Box([top_tab], layout=Layout(width="100%", height="100%"))

# Launching

In [None]:
#@title GUIを起動
# Bare final expression of the cell: the notebook renders the widget tree as the cell output.
main_window

# このノートブックについて
SDManualGUI から PNG Info 機能だけを切り出したノートブック。Colab 上のローカル処理のみ（CPU のみ、Gradio 不使用）で完結する。