In [None]:
from google.colab import drive
drive.mount('/content/drive')
#@title Download Requirements
print('🛠️ Installing TorchSDE, Compel (tokenization), and REALESRGAN...')
!pip install torchsde > /dev/null
!pip install compel > /dev/null
!pip install realesrgan > /dev/null
!pip install basicsr-fixed > /dev/null


print('🛠️ Installing GFPGAN...')
%cd /content
# !if [ -e /content/GFPGAN ]; then rm -r /content/GFPGAN; fi
!git clone https://github.com/TencentARC/GFPGAN.git > /dev/null
!pip install facexlib > /dev/null

%cd /content/GFPGAN
!pip install -r requirements.txt > /dev/null
!python3 setup.py develop > /dev/null 2>&1
!wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P experiments/pretrained_models > /dev/null

!mkdir /content/upscale
!mkdir /content/upscaled

print('🛠️ Installing Toolkit...')

In [None]:
#@title Git Pull
# Clone the SDXL-Toolkit repo, check out its img2img branch, and import the
# libraries used by the remaining cells.
# Uncomment the next line to force a fresh clone; otherwise `git clone` reports an
# error on re-runs because /content/toolkit already exists.
# !if [ -e /content/toolkit ]; then rm -r /content/toolkit; fi
!git clone https://github.com/mothballs-x/SDXL-Toolkit.git /content/toolkit > /dev/null

# Comment out these lines for more reliable txt2img generation
%cd /content/toolkit
!git checkout img2img

import ipywidgets as widgets
from IPython.display import display, clear_output

# The checked-out branch's .version text decides which pipeline class is imported
# below; later cells also test it for the substring 'img2img'.
# read() keeps any trailing newline, which also shows up in the print.
with open('/content/toolkit/.version') as file:
  version = file.read()

print(f'🛠️ Toolkit Version: {version}')
# Only one of the two pipeline classes is imported, depending on the branch.
if 'img2img' in version:
  from diffusers import StableDiffusionXLImg2ImgPipeline
else:
  from diffusers import StableDiffusionXLPipeline
from diffusers import AutoencoderKL

from transformers import CLIPTextModel, CLIPTokenizer # For token length check, not implemented here
from compel import Compel, ReturnedEmbeddingsType
# NOTE(review): `display` and `widgets` are imported again here (already imported above).
from IPython.display import display
import ipywidgets as widgets
from pathlib import Path
import subprocess
import torch
import sys
import re
import os
import time

In [None]:
# @title Setup:
# Make the toolkit's `utilities` package importable; these paths must be on
# sys.path before the imports below run.
tools = Path('/content/toolkit/')
sys.path.append(str(tools))
sys.path.append(str(tools / 'utilities'))

from utilities.managers import PromptManager, LoraManager, ImageGenerator
from utilities.download import download_file
from utilities.img_util import ImageUtility
from utilities.gridview import ImageGrid
from utilities.embedding import create_embedding
from utilities.logger import log_generation as logger
from utilities.upscaler import factorize, upscale
import utilities.schedulers as schedulers


# Status prefixes used by every cell's messages.
emojis = {'success': '✅', 'warning': '⚠️', 'error': '❌',
          'info': '🔍', 'loading': '⏳', 'bullet': '🔹'}


# Widgets
DEVICE = widgets.Dropdown(
    options=['cuda', 'cpu'],
    description='Device:',
    disabled=False,
    layout=widgets.Layout(width='400px')
    )

# Precision for the dropdown's initial value: float16 on CUDA, float32 otherwise.
# NOTE(review): this line is evaluated once; it does not by itself follow later
# changes to DEVICE.
DTYPE = torch.float16 if DEVICE.value == 'cuda' else torch.float32

ROOT_DIR_TEXT = widgets.Text(value='/content/drive/My Drive/generated',
                        placeholder='Path for Image Output: ',
                        layout=widgets.Layout(width='400px'),
                        style={'description_width': 'auto'},
                        description='Root Directory: '
                        )

# A Password widget masks the API token on screen; other cells read `.value`
# exactly as they would from a Text widget.
CIVITAI_TOKEN = widgets.Password(value="",
                                 placeholder='Enter API Token',
                                 description='Civitai Token: ',
                                 layout=widgets.Layout(width='400px')
                                 )
LOG_PATH = widgets.Text(value='/content/drive/My Drive/generated/gen_log.json',
                        placeholder='Absolute path of JSON log file',
                        layout=widgets.Layout(width='400px'),
                        style={'description_width': 'auto'},
                        description='Log File: '
                        )

set_button = widgets.Button(description='Set Parameters', button_style='success')

setup_output = widgets.Output()

def check_cuda(change):
  """Observer for DEVICE: fall back to CPU when CUDA is unavailable and refresh DTYPE.

  `change` is the traitlets change record (unused; DEVICE.value is read directly).
  Updates the module-level DTYPE so float16 is only chosen while DEVICE is 'cuda';
  previously a CPU fallback kept the float16 computed at cell start.
  """
  global DTYPE
  if DEVICE.value == 'cuda' and not torch.cuda.is_available():
    # Assigning here re-triggers this observer with 'cpu'; that nested call is harmless.
    DEVICE.value = 'cpu'
    with setup_output:
      clear_output(wait=True)
      print(f'{emojis["warning"]} CUDA not available, using CPU')
  DTYPE = torch.float16 if DEVICE.value == 'cuda' else torch.float32

DEVICE.observe(check_cuda, names='value')
# Observers only fire on changes, so also validate the dropdown's initial 'cuda' value.
check_cuda(None)

resources = Path('/content/toolkit/resources')
grid_view = None
img_util = None


def set_up(b):
  """Validate the output/log paths and create the image utility and grid viewer.

  Callback for ``set_button``. Sets the module-level ``img_util`` (numbering new
  images after the ``*.png`` files already in the root directory) and ``grid_view``.
  Problems are reported in ``setup_output`` instead of being raised.
  """
  global grid_view
  global img_util

  root_dir = Path(ROOT_DIR_TEXT.value)
  if not root_dir.exists():
    with setup_output:
      clear_output(wait=True)
      print(f'{emojis["error"]} Path {root_dir} does not exist')
    return

  log_path = Path(LOG_PATH.value)
  if not log_path.exists():
    # The root is validated first. Missing parent folders are created, because
    # a bare touch() raises FileNotFoundError when the log's directory is absent.
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_path.touch()

  current_count = len(list(root_dir.glob("*.png")))
  img_util = ImageUtility(root_dir, count=current_count + 1)
  grid_view = ImageGrid()
  with setup_output:
    clear_output(wait=True)
    print(f'{emojis["success"]} Setup complete')
    time.sleep(2.5)
    print(f'{emojis["info"]} Current Image Count in {root_dir}: {current_count}')


set_button.on_click(set_up)

display(DEVICE, CIVITAI_TOKEN, ROOT_DIR_TEXT, LOG_PATH, set_button, setup_output)

In [None]:
#@title Models

import json
import subprocess

!if [ ! -e "/content/models" ]; then mkdir "/content/models"; fi
model_json = resources / 'models.json'

if model_json.exists():
  with open(model_json, 'r') as f:
    models = json.load(f)

  for values in models.values():
    for key, value in list(values.items()):
      civit_api = re.compile(r'^https://civitai\.com/api/download/models/\d{6,7}\?(?:[\w\d]+=[\w\d\-_\.]+&?)*$')
      if value.startswith('https://civitai.com/'):
        if civit_api.match(value):
          # Get front half
          begin = re.compile(r'https.*(?=\?)')
          end = re.compile(r'(?<=\?).*$')
          front = begin.search(value).group(0)
          back = end.search(value).group(0)
          values[key] = f'{front}?token={CIVITAI_TOKEN.value}&{back}'
        else:
          print('❌ A model link is improperly formatted...')


selected_model_type = widgets.Output()
selected_model_name = widgets.Output()

vbox_selected = widgets.VBox([selected_model_type, selected_model_name])

model_type = None
model_name = None

def select_model(m_type, m_name):
    """Remember the chosen catalogue entry and echo it into the two output widgets.

    Sets the module-level ``model_type`` / ``model_name`` that ``create_pipeline`` reads.
    """
    global model_type, model_name
    model_type, model_name = m_type, m_name
    time.sleep(0.5)

    panels = (
        (selected_model_type, f"{emojis['bullet']} Model Type: {m_type}"),
        (selected_model_name, f"{emojis['bullet']} Model Name: {m_name}"),
    )
    for panel, line in panels:
        with panel:
            panel.clear_output()
            print(line)

accordion_items = []
# The loop variables must not be named model_type/model_name. Those are the
# module-level selection globals read by create_pipeline, and reusing them here
# left the *last* catalogue entry silently "selected" before any click.
for entry_type, model_entries in models.items():
  buttons = []
  for entry_name in model_entries.keys():
    btn = widgets.Button(description=entry_name,
                         layout=widgets.Layout(width='auto'))
    # Default arguments freeze the current loop values (closures bind late).
    btn.on_click(lambda b, t=entry_type, n=entry_name: select_model(t, n))
    buttons.append(btn)
  vbox = widgets.VBox(buttons)
  accordion_items.append((entry_type, vbox))

accordion = widgets.Accordion(children=[item[1] for item in accordion_items],
                              layout=widgets.Layout(width='50%'),
                              )
for i, (title, _) in enumerate(accordion_items):
  accordion.set_title(i, title)

# Start with every section collapsed.
accordion.selected_index = None

model_instance = widgets.Button(description='Create Pipeline', button_style='success')

# Optional replacement VAEs (Hugging Face repo ids).
vae = {'fixed': 'madebyollin/sdxl-vae-fp16-fix',
       'old': 'stabilityai/sdxl-vae'}

run_vae = widgets.Checkbox(value=False, description='Set VAE')

choose_vae = widgets.Dropdown(
    options=['fixed', 'old'],
    description='VAE :',
    disabled=False,
    layout=widgets.Layout(width='400px')
    )

hbox_choose = widgets.HBox([accordion, vbox_selected])

hbox_run = widgets.HBox([model_instance, run_vae])
vae_vbox = widgets.VBox()

def show_vae(b):
  """Show the VAE dropdown only while the 'Set VAE' checkbox is ticked."""
  # `children` must be a sequence of widgets; the old `children = False` raised a
  # TraitError whenever the box was unticked.
  vae_vbox.children = [choose_vae] if run_vae.value else []

run_vae.observe(show_vae, names='value')

pipe_output = widgets.Output()

pipe = None

def create_pipeline(b):
  """Load the SDXL pipeline for the model chosen in the accordion into ``pipe``.

  Callback for ``model_instance``. ``models[model_type][model_name]`` is treated as
  one of three things:
    - a Hugging Face repo id (``owner/name``);
    - a local ``/content/...`` checkpoint;
    - a Civitai URL, first downloaded to ``/content/models/<model_name>.safetensors``.
  When 'Set VAE' is ticked, the selected VAE replaces ``pipe.vae``. Progress and
  errors are shown in ``pipe_output``.

  Raises:
    ValueError: if the Civitai download fails (API token redacted from the message).
  """
  global pipe

  def report(message):
    with pipe_output:
      pipe_output.clear_output()
      print(message)

  if model_type is None or model_name is None:
    report(f'{emojis["error"]} Select a model first')
    return
  model_location = models[model_type][model_name]

  # Only the class matching the checked-out branch was imported.
  SDXLPipe = StableDiffusionXLImg2ImgPipeline if 'img2img' in version else StableDiffusionXLPipeline

  device = DEVICE.value
  # Derive the dtype from the device in use *now*; half precision only on CUDA.
  dtype = torch.float16 if device == 'cuda' else torch.float32

  hf_repo = re.compile(r'[\w\d_-]+/[\w\d_-]')
  local = re.compile(r'/content.*$')
  civitai = re.compile(r'https://civitai.*')

  report(f'{emojis["loading"]} Loading Pipeline...')
  if hf_repo.match(model_location):
    pipe = SDXLPipe.from_pretrained(
      model_location,
      torch_dtype=dtype,
      ).to(device)

  elif local.match(model_location):
    pipe = SDXLPipe.from_single_file(
      model_location,
      torch_dtype=dtype,
      use_safetensors=True,
      ).to(device)

  elif civitai.match(model_location):
    destination = f'/content/models/{model_name}.safetensors'
    run = subprocess.run(
      ['wget', '-O', destination, model_location],
      capture_output=True, text=True
      )
    if run.returncode != 0:
      # wget echoes the URL, which carries the API token.
      stderr = run.stderr
      if CIVITAI_TOKEN.value:
        stderr = stderr.replace(CIVITAI_TOKEN.value, '***')
      report(f'{emojis["error"]} Failed to download civitai model')
      raise ValueError(f'Failed to download civitai model: {stderr}')
    pipe = SDXLPipe.from_single_file(
      destination,
      torch_dtype=dtype,
      ).to(device)

  else:
    # Stop here: the VAE step below would otherwise dereference a missing pipeline.
    report(f'{emojis["error"]} Problem loading the model, check model path and try again')
    return

  # VAE (moved to the same device as the pipeline, not hard-coded to CUDA)
  if run_vae.value:
    pipe.vae = AutoencoderKL.from_pretrained(
      vae[choose_vae.value],
      torch_dtype=dtype,
      use_safetensors=True,
    ).to(device)

  report(f'{emojis["success"]} Loaded Pipeline')

model_instance.on_click(create_pipeline)

# pipe_output must be displayed, otherwise every progress/error message is invisible.
display(hbox_choose, hbox_run, vae_vbox, pipe_output)

In [None]:
#@title Embeddings
!if [ ! -e "/content/embeddings" ]; then mkdir "/content/embeddings"; fi

embedding_output = widgets.Output()

with open(resources / 'embeddings.json') as embed_file:
  embeddings = json.load(embed_file)

embedding_dict = {}
embedding_tokens = {'pos': [], 'neg': []}

for embedding in embeddings:
  name = list(embedding.keys())[0]
  embedding_dict[name] = embedding[name]


for value in embedding_dict.values():
  for value in value.values():
    url_start = re.compile(r'https://civitai\.com/api/download/models/\d{6,7}')
    url_end = re.compile(r'(?<=\?).*$')
    if url_start.match(value['link']):
      value['link'] = url_start.search(value['link']).group(0) + '?' + f'token={CIVITAI_TOKEN.value}&' + url_end.search(value['link']).group(0)


combo_options = list(embedding_dict.keys())

embeddings_combo = widgets.Combobox(
    options=combo_options,
    description='Embedding :',
    placeholder='Double-click for options or enter path/url',
    disabled=False,
    layout=widgets.Layout(width='400px')
    )

token_text = widgets.Text(placeholder='Enter Embedding Token')

pos_neg_radio = widgets.RadioButtons(
    options=['pos', 'neg'],
    description='Embedding Type:',
    disabled=False,
    layout=widgets.Layout(width='400px')
    )

token_vbox = widgets.VBox()

def get_token(change):
  """Show token/type inputs only for custom embeddings given as a URL or /content path."""
  value = embeddings_combo.value
  is_custom = value not in combo_options and re.match(r'(https:/)|(/content)', value)
  token_vbox.children = (token_text, pos_neg_radio) if is_custom else []

embeddings_combo.observe(get_token, names='value')

embedding_button = widgets.Button(
    description='Add Embedding', button_style='info')

def _report_embedding(message):
  """Replace the contents of ``embedding_output`` with one status line."""
  with embedding_output:
    embedding_output.clear_output()
    print(message)


def _run_fetch(command, failure_message):
  """Run a wget/cp argv list; on a non-zero exit report ``failure_message`` and return False."""
  run = subprocess.run(command, capture_output=True, text=True)
  if run.returncode != 0:
    _report_embedding(failure_message)
    return False
  return True


def _add_custom_embedding(source):
  """Fetch a user-supplied embedding (Civitai URL or /content path) under the typed token.

  The file is stored as ``/content/embeddings/<token>.safetensors``. The token is
  recorded under the type picked in ``pos_neg_radio``. Returns True on success.
  """
  token = token_text.value.strip()
  if not token:
    _report_embedding(f'{emojis["error"]} You must include a token with embedding')
    return False
  output = f'/content/embeddings/{token}.safetensors'
  if source.startswith('https://civitai'):
    fetched = _run_fetch(['wget', '-O', output, source],
                         f'{emojis["error"]} Failed to download {source}')
  elif source.startswith('/content'):
    fetched = _run_fetch(['cp', source, output],
                         f'{emojis["error"]} Failed to copy {source}')
  else:
    _report_embedding(f'{emojis["error"]} Failed to find embedding, check path/url')
    return False
  if not fetched:
    return False
  # Runs for both downloads and copies (previously only copied files were registered).
  if not Path(output).exists():
    _report_embedding(f'{emojis["error"]} Embedding file not found: {output}')
    return False
  create_embedding(pipe, output, token)
  embedding_tokens[pos_neg_radio.value].append(token)
  return True


def _add_catalogue_embedding(name):
  """Download and register the positive and/or negative variant of a catalogue entry."""
  for variant, kind in (('positive', 'pos'), ('negative', 'neg')):
    entry = embedding_dict[name].get(variant)
    if not entry:
      continue
    output = f'/content/embeddings/{name}_{kind}.safetensors'
    link, token = entry['link'], entry['token']
    # The link carries the API token, so the error names the entry instead.
    if not _run_fetch(['wget', '-O', output, link],
                      f'{emojis["error"]} Failed to download {variant} embedding for {name}'):
      return False
    create_embedding(pipe, output, token)
    embedding_tokens[kind].append(token)
    print(f'✅ Loaded embedding token: {token}')
  return True


def add_embedding(b):
  """Callback for ``embedding_button``: add the embedding chosen/typed in ``embeddings_combo``.

  Catalogue names (``combo_options``) fetch their predefined variants. Anything else
  is treated as a custom Civitai URL or ``/content`` path that needs a token.
  Registered tokens are appended to ``embedding_tokens``, which the Prompts cell
  passes to PromptManager.
  """
  name = embeddings_combo.value
  if pipe is None:
    _report_embedding(f'{emojis["error"]} Create the pipeline before adding embeddings')
    return
  _report_embedding(f'{emojis["loading"]} Adding {name}...')
  if name in combo_options:
    added = _add_catalogue_embedding(name)
  else:
    added = _add_custom_embedding(name)
  if not added:
    return
  # Reset the type selector only after it has been read for this embedding
  # (it used to be reset first, so 'neg' could never take effect).
  pos_neg_radio.value = 'pos'
  print(f'Current Embedding Tokens: {embedding_tokens}')
  _report_embedding(f'{emojis["success"]} Added {name}')

embedding_button.on_click(add_embedding)
embedding_hbox = widgets.HBox([embeddings_combo, embedding_button])
display(embedding_hbox, token_vbox, embedding_output)

In [None]:
#@title LoRAs
# import importlib
# importlib.reload(utilities.managers)
from utilities.managers import LoraManager

# Do mostly UI here, instantiate automatically and then control:
# add, remove, weight change, delete, clear weights, view weights
# NOTE(review): LoraManager receives whatever `pipe` exists when this cell runs;
# create the pipeline in the Models cell first.
lora_manager = LoraManager(pipe, civitai_token=CIVITAI_TOKEN.value)

with open(resources / 'loras.json') as lora_file:
  loras = json.load(lora_file)

lora_dict = {lora['name']: lora for lora in loras}

# Append the Civitai API token to every catalogue link. Join with '&' when the
# link already has a query string; a second '?' produces an invalid URL.
for entry in lora_dict.values():
  separator = '&' if '?' in entry['link'] else '?'
  entry['link'] = f"{entry['link']}{separator}token={CIVITAI_TOKEN.value}"

items = [widgets.Label(name) for name in lora_dict.keys()]
lora_list = widgets.GridBox(items, layout=widgets.Layout(grid_template_columns="repeat(5, 200px)"))

lora_heading = widgets.HTML(
    value='<h4 style="color: lightblue;">Get Lora Info:</h4>'
)

lora_input = widgets.Text(placeholder="Enter lora name here")
lora_search_button = widgets.Button(description="Check LoRA info", button_style='info')

lora_hbox = widgets.HBox([lora_input, lora_search_button])
# Dedicated output for this lookup. The "Lora Manager" cell defines its own
# global `lora_output`, which previously redirected these messages into that cell.
lora_info_output = widgets.Output()

def check_lora(b):
  """Print the catalogue entry for the LoRA named in ``lora_input``.

  The stored link has the Civitai API token appended, so it is redacted before display.
  """
  name = lora_input.value.strip()
  with lora_info_output:
    lora_info_output.clear_output(wait=True)
    if name not in lora_dict:
      print("❌ LoRA not found, check the list and try again.")
      return
    entry = lora_dict[name]
    link = re.sub(r'token=[^&]*', 'token=***', entry['link'])
    print(f"🔹 Name: {name}")
    print(f"🔹 Link: {link}")
    print(f"🔹 Triggers: {entry['triggers']}")
    print(f"🔹 Model: {entry['model']}")

lora_search_button.on_click(check_lora)
display(lora_list, lora_heading, lora_hbox, lora_info_output)

In [None]:
#@title Lora Manager
lora_names = list(lora_manager.list_loras().keys())

lora_names_heading = widgets.HTML(
    value='<h4 style="color: lightblue;">Loaded LoRAs:</h4>'
)


lora_select = widgets.Select(
    options=lora_names,
    disabled=False,
    layout=widgets.Layout(width='400px')
)

lora_text = widgets.Text(
    placeholder="Enter lora name/path here",
    layout=widgets.Layout(width='400px'),
)
lora_namer = widgets.Text(
    placeholder='Enter name for new LoRA',
    layout=widgets.Layout(width='400px'),
)

name_vbox = widgets.VBox()

def show_namer(change):
  """Reveal the file-name field only while a Civitai URL is being entered."""
  needs_name = lora_text.value.startswith('https://civitai')
  name_vbox.children = [lora_namer] if needs_name else []

lora_text.observe(show_namer)


# Buttons
add_lora = widgets.Button(description='Add LoRA', button_style='success')
# `justify` is not a Layout trait (it had no effect); `justify_content` is.
change_weight_button = widgets.Button(description='Change Weight', button_style='info', layout=widgets.Layout(width='150px', justify_content='flex-start'))
remove_lora_button = widgets.Button(description='Remove LoRA', button_style='danger')

# Output
lora_output = widgets.Output()

weight_input = widgets.BoundedFloatText(
                                  description="Weight:",
                                  value=0.0,
                                  min=-4.0,
                                  max=4.0,
                                  step=0.1
                                  )
def select_change(change):
  """Show the selected LoRA's current weight in ``weight_input`` (0.0 when none is selected)."""
  if lora_select.value:
    weight_input.value = lora_manager.loras[lora_select.value].weight
  else:
    weight_input.value = 0.0

# React to selection changes only, not to every trait update on the Select widget.
lora_select.observe(select_change, names='value')

def add_lora_button(b):
  """Add a LoRA from a local path, a Civitai download URL, or the catalogue.

  Callback for ``add_lora``. The text in ``lora_text`` selects the source:
    - ``/content/...`` files are added under their file stem;
    - Civitai URLs are downloaded, with the API token injected, to
      ``/content/loras/<lora_namer>.safetensors``;
    - names in ``lora_dict`` use their catalogue link.
  ``lora_select`` is refreshed after a successful add. All messages go to ``lora_output``.
  """
  source = lora_text.value.strip()

  def report(message):
    with lora_output:
      lora_output.clear_output(wait=True)
      print(message)

  if source.startswith('/content'):
    lora_local = Path(source)
    if not lora_local.exists():
      report(f'{emojis["error"]} LoRA not found, check path')
      return
    report(f'{emojis["loading"]} Adding {lora_local.name}...')
    lora_manager.add_lora(str(lora_local), lora_local.stem)
    added_name = source

  elif source.startswith('https://civitai'):
    lora_name = lora_namer.value.strip()
    if lora_name == '':
      report(f'{emojis["error"]} LoRA name cannot be empty')
      return
    url_start = re.match(r'https://civitai\.com/api/download/models/\d+', source)
    if url_start is None:
      report(f'{emojis["error"]} LoRA not found, check link')
      return
    # Keep any existing query (everything after the first '?') behind the injected token.
    query = source.partition('?')[2]
    civit_url = f'{url_start.group(0)}?token={CIVITAI_TOKEN.value}'
    if query:
      civit_url += f'&{query}'
    report(f'{emojis["loading"]} Adding {lora_name}...')
    destination = Path('/content/loras') / f'{lora_name}.safetensors'
    # wget -O does not create missing directories.
    destination.parent.mkdir(parents=True, exist_ok=True)
    run = subprocess.run(['wget', '-O', str(destination), civit_url], capture_output=True, text=True)
    if run.returncode != 0:
      # A failed download used to be silent. stderr is not shown because wget
      # echoes the URL, which contains the API token.
      report(f'{emojis["error"]} Failed to download {lora_name}, check link and token')
      return
    lora_manager.add_lora(str(destination), lora_name)
    added_name = lora_name

  elif source in lora_dict:
    report(f'{emojis["loading"]} Adding {source}...')
    lora_manager.add_lora(lora_dict[source]['link'], source)
    added_name = source

  else:
    report(f'{emojis["error"]} LoRA not found')
    return

  lora_select.options = list(lora_manager.list_loras().keys())
  report(f'{emojis["success"]} Added {added_name}')


def change_weight(b):
  """Apply ``weight_input`` to the LoRA currently selected in ``lora_select``."""
  selected = lora_select.value
  with lora_output:
    lora_output.clear_output(wait=True)
    if not selected:
      # Nothing selected (e.g. empty list): indexing loras with None would raise.
      print(f'{emojis["warning"]} Select a LoRA first')
      return
  lora_manager.loras[selected].change_weight(weight_input.value)
  with lora_output:
    lora_output.clear_output(wait=True)
    print(f'{emojis["success"]} Updated weight for {selected}')

def remove_lora(b):
  """Delete the LoRA selected in ``lora_select`` and refresh the list.

  The name is captured before ``options`` is reassigned. Reassigning the options
  changes ``lora_select.value``, so the old message reported the wrong LoRA.
  """
  selected = lora_select.value
  if not selected:
    with lora_output:
      lora_output.clear_output(wait=True)
      print(f'{emojis["warning"]} Select a LoRA first')
    return
  lora_manager.delete_lora(selected)
  lora_select.options = list(lora_manager.list_loras().keys())
  with lora_output:
    lora_output.clear_output(wait=True)
    print(f'{emojis["success"]} Removed {selected}')


# Connect the three LoRA buttons to their callbacks.
add_lora.on_click(add_lora_button)
change_weight_button.on_click(change_weight)
remove_lora_button.on_click(remove_lora)

buttons_hbox = widgets.HBox(
    [add_lora, change_weight_button, remove_lora_button],
    layout=widgets.Layout(width='400px', justify_content='space-between'),
)

display(
    lora_names_heading, lora_text, name_vbox,
    lora_select, weight_input, buttons_hbox,
    lora_output,
)

In [None]:
#@title Prompts
import pandas as pd

# Prompter
prompter_header = widgets.HTML(
    value='<h3 style="color: lightblue;">Prompts:</h3><p style="color: lightblue;"><b>Add tags to persist across prompts:</b></p>'
)

initial_tags = widgets.Text(placeholder="Add tags here...")

create_prompter_button = widgets.Button(description='Create Prompter', button_style='info')

pony_check = widgets.Checkbox(value=False, description='Pony prompts', disabled=False, indent=False)

prompter_output = widgets.Output()

prompter_first_vbox = widgets.VBox(
    [prompter_header, initial_tags, create_prompter_button, pony_check],
    layout=widgets.Layout(width='auto'),
)

prompter = None
df = pd.read_csv(resources / 'danbooru-tags.csv')

def create_prompter(b):
  """Build the module-level ``prompter`` from the tag box, embedding tokens and tag table."""
  global prompter

  def show(message):
    with prompter_output:
      prompter_output.clear_output()
      print(message)

  show(f'{emojis["loading"]} Creating Prompter...')
  prompter = PromptManager(
      pipe,
      initial_tags.value,
      embedding_tokens['pos'],
      embedding_tokens['neg'],
      df=df,
      pony=pony_check.value,
  )
  show(f'{emojis["success"]} Created Prompter')

prompt_vbox = widgets.VBox()

prompt_text = widgets.Textarea(
    placeholder='Enter prompt here...',
    layout=widgets.Layout(width='400px', height='200px'),
)

prompt_create_button = widgets.Button(description='Create Prompt', button_style='success')
negative_prompt = widgets.Text(
    placeholder='Enter negative prompt here...',
    layout=widgets.Layout(width='400px'),
)
rand_check = widgets.Checkbox(value=False, description='random tags', disabled=False, indent=False)
shuffle_check = widgets.Checkbox(value=False, description='shuffle tags', disabled=False, indent=False)

add_lora_triggers_b = widgets.Button(description='Add LoRA Triggers', button_style='info')

hbox_neg_and_rand = widgets.HBox([rand_check, shuffle_check])

hbox_buttons = widgets.HBox([prompt_create_button, add_lora_triggers_b])

randan_vbox = widgets.VBox()

def on_prompter_create(b):
  """Reveal the prompt editor and the random-tag panel after 'Create Prompter' is clicked.

  The randan_* widgets are defined further down this cell and are looked up at click time.
  """
  editor_rows = [prompt_text, negative_prompt, hbox_neg_and_rand, hbox_buttons]
  prompt_vbox.children = editor_rows
  randan_hbox.children = [randan_count, randan_button]
  randan_vbox.children = [randan_heading, randan_hbox, randan_output]

prompt = None

def create_prompt(b):
  """Build the prompt from the editor widgets via the prompter and store it in ``prompt``.

  The module-level ``prompt`` is what ``generate_images`` passes to the generator.
  """
  global prompt
  if prompter is None:
    # Clicking before 'Create Prompter' previously raised AttributeError on None.
    with prompter_output:
      prompter_output.clear_output()
      print(f'{emojis["error"]} Create the prompter first')
    return
  prompt = prompter.create_prompt(prompt_text.value, negative_prompt.value, rand_tags=rand_check.value, shuffle=shuffle_check.value)
  with prompter_output:
    prompter_output.clear_output()
    print(f'{emojis["success"]} Created Prompt')
    print(f'{emojis["bullet"]} Positive: {prompter.pos_prompt}')
    print(f'{emojis["bullet"]} Negative: {prompter.neg_prompt}')

def add_lora_triggers(b):
  """Append the trigger words of every loaded catalogue LoRA to the prompt textbox.

  Triggers come from ``lora_dict``; LoRAs added from a path or URL are not in it.
  A ', ' separator is inserted when the textbox already has text, so the first
  trigger is no longer glued onto the last prompt word.
  """
  triggers = [
      lora_dict[name]['triggers']
      for name in lora_manager.list_loras().keys()
      if name != '' and name in lora_dict and lora_dict[name]['triggers'] != ''
  ]
  if not triggers:
    return
  joined = ', '.join(triggers)
  base = prompt_text.value.rstrip().rstrip(',').rstrip()
  prompt_text.value = f'{base}, {joined}' if base else joined

full_prompt_vbox = widgets.VBox([prompter_first_vbox, prompt_vbox])

# Random Tags creator
randan_heading = widgets.HTML(
    value='<h4 style="color: lightblue;">Tags:</h4><p style="color: lightblue;"><b>Generate Random Tags:</b></p>'
)
randan_button = widgets.Button(description='Get Random Tags', button_style='info',
                               layout=widgets.Layout(width='100px'))

randan_count = widgets.BoundedIntText(
    value=10,
    min=5,
    max=50,
    step=1,
    disabled=False,
    layout=widgets.Layout(width='100px')
)
randan_hbox = widgets.HBox([randan_button, randan_count])
randan_output = widgets.Output()


def randan_get(b):
  """List ``randan_count`` random tags from the prompter, one per line, in ``randan_output``."""
  with randan_output:
    randan_output.clear_output()
    if prompter is None:
      # Previously raised AttributeError on None before a prompter was created.
      print(f'{emojis["error"]} Create the prompter first')
      return
    random_tags = prompter.randan(count=randan_count.value)
    # The prompter may hand back a comma-joined string rather than a list.
    if isinstance(random_tags, str):
      random_tags = random_tags.split(', ')
    for tag in random_tags:
      print(f'{emojis["bullet"]} {tag}')

randan_button.on_click(randan_get)


full_hbox = widgets.HBox([full_prompt_vbox, randan_vbox], layout=widgets.Layout(width="75%"))
# Events
add_lora_triggers_b.on_click(add_lora_triggers)
prompt_create_button.on_click(create_prompt)
create_prompter_button.on_click(create_prompter)
create_prompter_button.on_click(on_prompter_create)

display(full_hbox, prompter_output)

In [None]:
#@title Generation

# Set to False to block the img2img-upload and hiRes paths in generate_images.
production = True

# Schedulers
scheduler_heading = widgets.HTML(value='<h4 style="color: lightblue;">Schedulers</h4>')
scheduler_types = ['none', 'euler', 'dpm', 'dpm_sde', 'pndm', 'ddim', 'lcm']
scheduler_dropdown = widgets.Dropdown(
    options=scheduler_types,
    placeholder='choose scheduler',
    layout=widgets.Layout(width='150px'),
)

# Spacing options passed to the Euler scheduler setter.
spacings = ['linspace', 'leading', 'trailing']
spacings_select = widgets.Select(
    options=spacings,
    disabled=False,
    rows=3,
    layout=widgets.Layout(width='100px'),
)
sched_hbox = widgets.HBox([scheduler_dropdown])
sched_output = widgets.Output()

def show_spacings(change):
  """Show the spacing selector next to the dropdown only while 'euler' is chosen."""
  extras = [spacings_select] if scheduler_dropdown.value == 'euler' else []
  sched_hbox.children = [scheduler_dropdown, *extras]

scheduler_dropdown.observe(show_spacings)
set_sched_button = widgets.Button(description='Set Scheduler', button_style='info')

def set_schedulers(b):
  """Apply the scheduler picked in ``scheduler_dropdown`` to the current ``pipe``.

  Callback for ``set_sched_button``. Euler additionally receives the selected
  spacing. Messages go to ``sched_output``.
  """
  choice = scheduler_dropdown.value

  def report(message):
    with sched_output:
      sched_output.clear_output()
      print(message)

  if choice == 'none':
    report(f'{emojis["warning"]} You should really choose a scheduler!')
    return
  if pipe is None:
    report(f'{emojis["error"]} Create the pipeline before setting a scheduler')
    return

  # Lambdas defer the attribute lookup so only the chosen setter is resolved.
  setters = {
      'euler': lambda: schedulers.set_euler_scheduler(pipe, spacing=spacings_select.value),
      'dpm': lambda: schedulers.set_dpm_scheduler(pipe),
      'dpm_sde': lambda: schedulers.set_dpm_sde_scheduler(pipe),
      'pndm': lambda: schedulers.set_pndm_scheduler(pipe),
      'ddim': lambda: schedulers.set_ddim_scheduler(pipe),
      'lcm': lambda: schedulers.set_lcm_scheduler(pipe),
  }
  setter = setters.get(choice)
  if setter is None:
    report(f'{emojis["error"]} Unknown scheduler: {choice}')
    return
  setter()
  report(f'{emojis["success"]} Scheduler set to {choice}')

set_sched_button.on_click(set_schedulers)
display(scheduler_heading, sched_hbox, set_sched_button, sched_output)

from PIL import Image
import io

# NOTE(review): ImageGenerator receives the `pipe` object that exists when this cell
# runs; re-run this cell after creating a new pipeline.
if pipe is None:
  print(f'{emojis["warning"]} No pipeline loaded yet; create one in the Models cell and re-run this cell')
generator = ImageGenerator(pipe, upscale)

generate_header = widgets.HTML(
    value='<h3 style="color: lightblue;">Generate Images:</h3>'
)

# Width/height choices in pixels (the duplicated 896 entry is removed).
dims = [640, 768, 832, 896, 1024, 1152, 1216, 1536]

width_dropdown = widgets.Dropdown(
    options=dims,
    value=1024,
    description='Width:',
    disabled=False,
    layout=widgets.Layout(width='300px')
)

height_dropdown = widgets.Dropdown(
    options=dims,
    value=1024,
    description='Height:',
    disabled=False,
    layout=widgets.Layout(width='300px')
)

# num_imgs -> BoundedIntText
num_imgs_text = widgets.BoundedIntText(
    description='Images:',
    value=4,
    min=1,
    max=8,
    step=1,
    layout=widgets.Layout(width='300px')
)
# steps -> BoundedIntText
steps_slider = widgets.BoundedIntText(
    description='Steps:',
    value=25,
    min=5,
    max=50,
    step=1,
)
# cfg -> Slider
cfg_slider = widgets.FloatSlider(
    description='CFG:',
    value=7.0,
    min=1.0,
    max=12.0,
    step=0.02,
    layout=widgets.Layout(width='325px')
)
# strength -> Slider
strength_slider = widgets.FloatSlider(
    description='Strength:',
    value=0.0,
    min=0.0,
    max=1.0,
    step=0.02,
)
# scale -> BoundedIntText
scale_text = widgets.BoundedIntText(
    description='Scale:',
    value=2,
    min=1,
    max=4,
    step=1,
)
# seed -> Text
seed_text = widgets.Text(
    placeholder='Enter seed here...',
    layout=widgets.Layout(width='275px')
)
# clipSkip -> Dropdown
clip_skip_drop = widgets.Dropdown(
    description='Clip-skip',
    options=['0', '1', '2'],
    value='1',
    layout=widgets.Layout(width='300px', flex_flow='row', flex_direction='row')
)

gen_type = widgets.ToggleButtons(
    options=['txt2img', 'img2img', 'hiRes', 'gfpgan', 'upscale'],
    value='txt2img',
    disabled=False,
    button_style='info',
    layout=widgets.Layout(width='200px', flex_flow='column', flex_direction='column', align_items='stretch'),
)

file_upload = widgets.FileUpload(
    accept='image/*',  # Accept image files only ('*' allowed anything)
    multiple=False,
    layout=widgets.Layout(width='300px')
)

file_path = widgets.Text(
    placeholder='Enter Local File Path',
    description='File Path: ',
    layout=widgets.Layout(width='300px')
)

gen_output = widgets.Output()

# `align` is not a Layout trait (it had no effect); align_items centres the children.
vertical_options = widgets.VBox([width_dropdown, height_dropdown,
        num_imgs_text,
        steps_slider,
        cfg_slider,
        ], layout=widgets.Layout(display='flex', align_items='center'))

row3_vbox = widgets.VBox([clip_skip_drop])
hbox_dual = widgets.HBox([gen_type, vertical_options, row3_vbox])

generate_button = widgets.Button(description='Generate', button_style='success')

upload_header = widgets.HTML(
    value='<h4 style="color: lightblue;">Upload Image:</h4>',
    layout=widgets.Layout(width='300px')
)


def add_image_options(change):
  """Swap the third column's extra inputs to match the selected generation mode."""
  extras_by_mode = {
      'img2img': [upload_header, file_upload, file_path, strength_slider],
      'hiRes': [scale_text, strength_slider],
      'gfpgan': [scale_text],
      'upscale': [scale_text],
  }
  row3_vbox.children = [clip_skip_drop] + extras_by_mode.get(gen_type.value, [])


def update_config(change):
  """Copy the current generation widget values onto ``generator.config``.

  Observer for the generation widgets. Both entries of ``steps`` receive the same value.
  """
  config = generator.config
  steps = int(steps_slider.value)
  config.cfg = float(cfg_slider.value)
  config.steps = (steps, steps)
  config.clip_skip = int(clip_skip_drop.value)
  config.strength = float(strength_slider.value)
  config.width = int(width_dropdown.value)
  config.height = int(height_dropdown.value)
  config.num_images = int(num_imgs_text.value)
  config.scale = int(scale_text.value)


# Events
# names='value' limits the callbacks to value changes, not every trait update.
gen_type.observe(add_image_options, names='value')
cfg_slider.observe(update_config, names='value')
steps_slider.observe(update_config, names='value')
clip_skip_drop.observe(update_config, names='value')
seed_text.observe(update_config, names='value')
strength_slider.observe(update_config, names='value')
width_dropdown.observe(update_config, names='value')
height_dropdown.observe(update_config, names='value')
num_imgs_text.observe(update_config, names='value')
scale_text.observe(update_config, names='value')

# Observers fire only on *changes*, so copy the initial widget values into the
# generator config once. Otherwise the config keeps ImageGenerator's own defaults
# until some widget is touched.
update_config(None)

images = None

def free_memory():
    """Release cached CUDA memory; does nothing when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def uploaded_file_bytes(value):
  """Return the content of the first file in a ``FileUpload.value``, or None if empty.

  Supports both layouts:
    - ipywidgets 7: a dict of filename -> {'content': bytes, ...};
    - ipywidgets 8: a tuple of dicts, each with a 'content' memoryview.
  """
  if not value:
    return None
  first = next(iter(value.values())) if isinstance(value, dict) else value[0]
  return bytes(first['content'])


def generate_images(b):
  """Run the mode selected in ``gen_type``, then save, log and display the images.

  Callback for ``generate_button``. Uses the module-level ``generator``, ``prompt``,
  ``img_util`` and ``grid_view``, and stores the latest results in ``images``.
  Returns early, with a message in ``gen_output``, when:
    - setup has not run;
    - the seed is not an integer;
    - the selected mode is disabled;
    - no input image is available;
    - the generator returned nothing.
  """
  global images
  global prompt

  def report(message):
    with gen_output:
      clear_output(wait=True)
      print(message)

  if img_util is None or grid_view is None:
    report(f'{emojis["error"]} Run the Setup cell ("Set Parameters") first')
    return

  seed_value = seed_text.value.strip()
  if seed_value:
    try:
      seed = int(seed_value)
    except ValueError:
      report(f'{emojis["error"]} Seed must be an integer, got {seed_value!r}')
      return
  else:
    seed = None

  report(f'{emojis["loading"]} Generating images...')

  mode = gen_type.value
  if mode == 'txt2img':
    if prompt is None:
      with gen_output:
        print(f"⚠️ Warning: Prompt is None! Creating a default placeholder.")
      prompt = ("placeholder prompt", "negative placeholder")
    results = generator.txt2img(prompt, seed=seed)

  elif mode == 'img2img':
    if 'img2img' not in version:
      report(f'{emojis["error"]} Image2Image Pipeline is disabled')
      return
    upload = uploaded_file_bytes(file_upload.value)
    if upload is None:
      local_path = file_path.value.strip()
      # Path('') resolves to the current directory, so an empty box must be rejected explicitly.
      if not local_path or not Path(local_path).exists():
        report(f'{emojis["error"]} File not found')
        return
      img = Image.open(local_path)
      # NOTE(review): argument order differs from the upload branch below
      # (prompt, img here vs img, prompt there); confirm against ImageGenerator.img2img.
      results = generator.img2img(prompt, img, seed=seed)
    else:
      if not production:
        report(f'{emojis["warning"]} Image2Image Pipeline is disabled. To use, switch to "img2img" branch and reload model')
        return
      img = Image.open(io.BytesIO(upload))
      results = generator.img2img(img, prompt, seed=seed)

  elif mode == 'hiRes':
    if 'img2img' not in version:
      report(f'{emojis["error"]} HiRes Pipeline is disabled')
      return
    if not production:
      report(f'{emojis["warning"]} HiRes Pipeline is disabled. To use, switch to "img2img" branch and reload model')
      return
    results = generator.hi_res(prompt)

  elif mode == 'gfpgan':
    results = generator.gfpgan(prompt)

  elif mode == 'upscale':
    results = generator.upscale(prompt)

  else:
    report(f'{emojis["error"]} Unknown generation type: {mode}')
    return

  if results is None:
    # Never fall through to re-saving/logging the previous run's images.
    report(f'{emojis["error"]} No images were generated')
    return
  images = results

  report(f'{emojis["loading"]} saving images...')
  img_paths = [img_util.save(img) for img in images]

  for img_path in img_paths:
    logger(img_path, LOG_PATH.value, model_name,
           generator, prompter, lora_manager,
           generator.config.current_seed)

  torch.cuda.empty_cache()
  with gen_output:
    clear_output(wait=True)
    grid_view.view_grid(images, 0.5)

# Render the generation panel and attach the Generate handler.
display(generate_header, hbox_dual, seed_text, generate_button, gen_output)
generate_button.on_click(generate_images)