Skip to content

Commit

Permalink
enhance console gui for invokeai-merge (#2480)
Browse files Browse the repository at this point in the history
- Added modest adaptive behavior; if the screen is wide enough the three
checklists of models will be arranged in a horizontal row.
- Added color support
# What it looks like
On a wide window:

![image](https://user-images.githubusercontent.com/111189/216495149-0ceed761-b829-4b21-8e90-0b7faf2c7b72.png)
On a narrow window:

![image](https://user-images.githubusercontent.com/111189/216495239-1d6615cf-0e7e-44fe-83d7-513819635d8a.png)
  • Loading branch information
lstein committed Feb 3, 2023
2 parents d6bd0cb + fc857f9 commit 66b312c
Showing 1 changed file with 135 additions and 52 deletions.
187 changes: 135 additions & 52 deletions ldm/invoke/merge_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import os
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Union

import npyscreen
import warnings
from diffusers import DiffusionPipeline
from omegaconf import OmegaConf

Expand All @@ -26,7 +28,6 @@

DEST_MERGED_MODEL_DIR = "merged_models"


def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5,
Expand Down Expand Up @@ -185,6 +186,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):

def __init__(self, parentApp, name):
self.parentApp = parentApp
self.ALLOW_RESIZE=True
self.FIX_MINIMUM_SIZE_WHEN_CREATED=False
super().__init__(parentApp, name)

@property
Expand All @@ -195,29 +198,94 @@ def afterEditing(self):
self.parentApp.setNextForm(None)

def create(self):
window_height,window_width=curses.initscr().getmaxyx()

self.model_names = self.get_model_names()

max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width*3 < window_width

self.add_widget_intelligent(
npyscreen.FixedText,
color='CONTROL',
value=f"Select two models to merge and optionally a third.",
editable=False,
)
self.add_widget_intelligent(
npyscreen.FixedText,
color='CONTROL',
value=f"Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 1',
color='GOOD',
editable=False,
rely=4 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
values=self.model_names,
value=0,
max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
)
self.add_widget_intelligent(
npyscreen.FixedText, name="Select up to three models to merge", value=""
npyscreen.FixedText,
value='MODEL 2',
color='GOOD',
editable=False,
relx=max_width+3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
self.models = self.add_widget_intelligent(
npyscreen.TitleMultiSelect,
name="Select two to three models to merge:",
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
name='(2)',
values=self.model_names,
value=None,
max_height=len(self.model_names) + 1,
value=1,
max_height=len(self.model_names),
max_width=max_width,
relx=max_width+3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value='MODEL 3',
color='GOOD',
editable=False,
relx=max_width*2+3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0,'None')
self.model3 = self.add_widget_intelligent(
npyscreen.SelectOne,
name='(3)',
values=models_plus_none,
value=0,
max_height=len(self.model_names)+1,
max_width=max_width,
scroll_exit=True,
relx=max_width*2+3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
)
self.models.when_value_edited = self.models_changed
for m in [self.model1,self.model2,self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
name="Name for merged model:",
labelColor='CONTROL',
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
labelColor='CONTROL',
value=False,
scroll_exit=True,
)
Expand All @@ -226,6 +294,7 @@ def create(self):
name="Merge Method:",
values=self.interpolations,
value=0,
labelColor='CONTROL',
max_height=len(self.interpolations) + 1,
scroll_exit=True,
)
Expand All @@ -236,47 +305,53 @@ def create(self):
step=0.05,
lowest=0,
value=0.5,
labelColor='CONTROL',
scroll_exit=True,
)
self.models.editing = True
self.model1.editing = True

def models_changed(self):
model_names = self.models.values
selected_models = self.models.value
if len(selected_models) > 3:
npyscreen.notify_confirm(
"Too many models selected for merging. Select two to three."
)
return
elif len(selected_models) > 2:
self.merge_method.values = ["add_difference"]
self.merge_method.value = 0
models = self.model1.values
selected_model1 = self.model1.value[0]
selected_model2 = self.model2.value[0]
selected_model3 = self.model3.value[0]
merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
self.merged_model_name.value = merged_model_name

if selected_model3 > 0:
self.merge_method.values=['add_difference'],
self.merged_model_name.value += f'+{models[selected_model3]}'
else:
self.merge_method.values = self.interpolations
self.merged_model_name.value = "+".join(
[model_names[x] for x in selected_models]
)
self.merge_method.values=self.interpolations
self.merge_method.value=0

def on_ok(self):
if self.validate_field_values() and self.check_for_overwrite():
self.parentApp.setNextForm(None)
self.editing = False
self.parentApp.merge_arguments = self.marshall_arguments()
npyscreen.notify("Starting the merge...")
npyscreen.notify('Starting the merge...')
else:
self.editing = True

def on_cancel(self):
sys.exit(0)

def marshall_arguments(self) -> dict:
models = [self.models.values[x] for x in self.models.value]
def marshall_arguments(self)->dict:
model_names = self.model_names
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0]-1])

args = dict(
models=models,
alpha=self.alpha.value,
interp=self.interpolations[self.merge_method.value[0]],
force=self.force.value,
merged_model_name=self.merged_model_name.value,
alpha = self.alpha.value,
interp = self.interpolations[self.merge_method.value[0]],
force = self.force.value,
merged_model_name = self.merged_model_name.value,
)
return args

Expand All @@ -289,15 +364,18 @@ def check_for_overwrite(self) -> bool:
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
)

def validate_field_values(self) -> bool:
def validate_field_values(self)->bool:
bad_fields = []
selected_models = self.models.value
if len(selected_models) < 2 or len(selected_models) > 3:
bad_fields.append("Please select two or three models to merge.")
model_names = self.model_names
selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
if self.model3.value[0] > 0:
selected_models.add(model_names[self.model3.value[0]-1])
if len(selected_models) < 2:
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:"
message = 'The following problems were detected and must be corrected:'
for problem in bad_fields:
message += f"\n* {problem}"
message += f'\n* {problem}'
npyscreen.notify_confirm(message)
return False
else:
Expand All @@ -322,10 +400,9 @@ def __init__(self):
) # precision doesn't really matter here

def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")


def run_gui(args: Namespace):
mergeapp = Mergeapp()
mergeapp.run()
Expand All @@ -338,8 +415,8 @@ def run_gui(args: Namespace):
def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert (
len(args.models) >= 1 and len(args.models) <= 3
), "provide 2 or 3 models to merge"
args.models and len(args.models) >= 1 and len(args.models) <= 3
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."

if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
Expand All @@ -353,6 +430,7 @@ def run_cli(args: Namespace):
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

merge_diffusion_models_and_commit(**vars(args))
print(f'>> Models merged into new model: "{args.merged_model_name}".')


def main():
Expand All @@ -365,17 +443,22 @@ def main():
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
args.cache_dir = cache_dir

try:
if args.front_end:
run_gui(args)
else:
run_cli(args)
print(f">> Conversion successful. New model is named {args.merged_model_name}")
except Exception as e:
print(f"** An error occurred while merging the pipelines: {str(e)}")
sys.exit(-1)
except KeyboardInterrupt:
sys.exit(-1)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
try:
if args.front_end:
run_gui(args)
else:
run_cli(args)
print(f'>> Conversion successful.')
except Exception as e:
if str(e).startswith('Not enough space'):
print('** Not enough horizontal space! Try making the window wider, or relaunch with a smaller starting size.')
else:
print(f"** An error occurred while merging the pipelines: {str(e)}")
sys.exit(-1)
except KeyboardInterrupt:
sys.exit(-1)

if __name__ == "__main__":
main()

0 comments on commit 66b312c

Please sign in to comment.