diff --git a/ldm/invoke/merge_diffusers.py b/ldm/invoke/merge_diffusers.py index 85eac1077c3..00d2599e695 100644 --- a/ldm/invoke/merge_diffusers.py +++ b/ldm/invoke/merge_diffusers.py @@ -5,6 +5,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team """ import argparse +import curses import os import sys from argparse import Namespace @@ -12,6 +13,7 @@ from typing import List, Union import npyscreen +import warnings from diffusers import DiffusionPipeline from omegaconf import OmegaConf @@ -26,7 +28,6 @@ DEST_MERGED_MODEL_DIR = "merged_models" - def merge_diffusion_models( model_ids_or_paths: List[Union[str, Path]], alpha: float = 0.5, @@ -185,6 +186,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): def __init__(self, parentApp, name): self.parentApp = parentApp + self.ALLOW_RESIZE=True + self.FIX_MINIMUM_SIZE_WHEN_CREATED=False super().__init__(parentApp, name) @property @@ -195,29 +198,94 @@ def afterEditing(self): self.parentApp.setNextForm(None) def create(self): + window_height,window_width=curses.initscr().getmaxyx() + self.model_names = self.get_model_names() - + max_width = max([len(x) for x in self.model_names]) + max_width += 6 + horizontal_layout = max_width*3 < window_width + + self.add_widget_intelligent( + npyscreen.FixedText, + color='CONTROL', + value=f"Select two models to merge and optionally a third.", + editable=False, + ) + self.add_widget_intelligent( + npyscreen.FixedText, + color='CONTROL', + value=f"Use up and down arrows to move, to select an item, and to move from one field to the next.", + editable=False, + ) + self.add_widget_intelligent( + npyscreen.FixedText, + value='MODEL 1', + color='GOOD', + editable=False, + rely=4 if horizontal_layout else None, + ) + self.model1 = self.add_widget_intelligent( + npyscreen.SelectOne, + values=self.model_names, + value=0, + max_height=len(self.model_names), + max_width=max_width, + scroll_exit=True, + rely=5, + ) self.add_widget_intelligent( - npyscreen.FixedText, name="Select up to three models to merge", value="" + npyscreen.FixedText, + value='MODEL 2', + color='GOOD', + editable=False, + relx=max_width+3 if horizontal_layout else None, + rely=4 if horizontal_layout else None, ) - self.models = self.add_widget_intelligent( - npyscreen.TitleMultiSelect, - name="Select two to three models to merge:", + self.model2 = self.add_widget_intelligent( + npyscreen.SelectOne, + name='(2)', values=self.model_names, - value=None, - max_height=len(self.model_names) + 1, + value=1, + max_height=len(self.model_names), + max_width=max_width, + relx=max_width+3 if horizontal_layout else None, + rely=5 if horizontal_layout else None, + scroll_exit=True, + ) + self.add_widget_intelligent( + npyscreen.FixedText, + value='MODEL 3', + color='GOOD', + editable=False, + relx=max_width*2+3 if horizontal_layout else None, + rely=4 if horizontal_layout else None, + ) + models_plus_none = self.model_names.copy() + models_plus_none.insert(0,'None') + self.model3 = self.add_widget_intelligent( + npyscreen.SelectOne, + name='(3)', + values=models_plus_none, + value=0, + max_height=len(self.model_names)+1, + max_width=max_width, scroll_exit=True, + relx=max_width*2+3 if horizontal_layout else None, + rely=5 if horizontal_layout else None, ) - self.models.when_value_edited = self.models_changed + for m in [self.model1,self.model2,self.model3]: + m.when_value_edited = self.models_changed self.merged_model_name = self.add_widget_intelligent( npyscreen.TitleText, name="Name for merged model:", + labelColor='CONTROL', value="", scroll_exit=True, ) self.force = self.add_widget_intelligent( npyscreen.Checkbox, name="Force merge of incompatible models", + labelColor='CONTROL', value=False, scroll_exit=True, ) @@ -226,6 +294,7 @@ def create(self): name="Merge Method:", values=self.interpolations, value=0, + labelColor='CONTROL', max_height=len(self.interpolations) + 1, scroll_exit=True, ) @@ -236,47 +305,53 @@ def create(self): step=0.05, lowest=0, value=0.5, + labelColor='CONTROL', scroll_exit=True, ) - self.models.editing = True + self.model1.editing = True def models_changed(self): - model_names = self.models.values - selected_models = self.models.value - if len(selected_models) > 3: - npyscreen.notify_confirm( - "Too many models selected for merging. Select two to three." - ) - return - elif len(selected_models) > 2: - self.merge_method.values = ["add_difference"] - self.merge_method.value = 0 + models = self.model1.values + selected_model1 = self.model1.value[0] + selected_model2 = self.model2.value[0] + selected_model3 = self.model3.value[0] + merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}' + self.merged_model_name.value = merged_model_name + + if selected_model3 > 0: + self.merge_method.values=['add_difference'], + self.merged_model_name.value += f'+{models[selected_model3]}' else: - self.merge_method.values = self.interpolations - self.merged_model_name.value = "+".join( - [model_names[x] for x in selected_models] - ) + self.merge_method.values=self.interpolations + self.merge_method.value=0 def on_ok(self): if self.validate_field_values() and self.check_for_overwrite(): self.parentApp.setNextForm(None) self.editing = False self.parentApp.merge_arguments = self.marshall_arguments() - npyscreen.notify("Starting the merge...") + npyscreen.notify('Starting the merge...') else: self.editing = True def on_cancel(self): sys.exit(0) - def marshall_arguments(self) -> dict: - models = [self.models.values[x] for x in self.models.value] + def marshall_arguments(self)->dict: + model_names = self.model_names + models = [ + model_names[self.model1.value[0]], + model_names[self.model2.value[0]], + ] + if self.model3.value[0] > 0: + models.append(model_names[self.model3.value[0]-1]) + args = dict( models=models, - alpha=self.alpha.value, - interp=self.interpolations[self.merge_method.value[0]], - force=self.force.value, - merged_model_name=self.merged_model_name.value, + alpha = self.alpha.value, + interp = self.interpolations[self.merge_method.value[0]], + force = self.force.value, + merged_model_name = self.merged_model_name.value, ) return args @@ -289,15 +364,18 @@ def check_for_overwrite(self) -> bool: f"The chosen merged model destination, {model_out}, is already in use. Overwrite?" ) - def validate_field_values(self) -> bool: + def validate_field_values(self)->bool: bad_fields = [] - selected_models = self.models.value - if len(selected_models) < 2 or len(selected_models) > 3: - bad_fields.append("Please select two or three models to merge.") + model_names = self.model_names + selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]])) + if self.model3.value[0] > 0: + selected_models.add(model_names[self.model3.value[0]-1]) + if len(selected_models) < 2: + bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}') if len(bad_fields) > 0: - message = "The following problems were detected and must be corrected:" + message = 'The following problems were detected and must be corrected:' for problem in bad_fields: - message += f"\n* {problem}" + message += f'\n* {problem}' npyscreen.notify_confirm(message) return False else: @@ -322,10 +400,9 @@ def __init__(self): ) # precision doesn't really matter here def onStart(self): - npyscreen.setTheme(npyscreen.Themes.DefaultTheme) + npyscreen.setTheme(npyscreen.Themes.ElegantTheme) self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") - def run_gui(args: Namespace): mergeapp = Mergeapp() mergeapp.run() @@ -338,8 +415,8 @@ def run_gui(args: Namespace): def run_cli(args: Namespace): assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" assert ( - len(args.models) >= 1 and len(args.models) <= 3 - ), "provide 2 or 3 models to merge" + args.models and len(args.models) >= 1 and len(args.models) <= 3 + ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage." if not args.merged_model_name: args.merged_model_name = "+".join(args.models) @@ -353,6 +430,7 @@ def run_cli(args: Namespace): ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' merge_diffusion_models_and_commit(**vars(args)) + print(f'>> Models merged into new model: "{args.merged_model_name}".') def main(): @@ -365,17 +443,22 @@ def main(): ] = cache_dir # because not clear the merge pipeline is honoring cache_dir args.cache_dir = cache_dir - try: - if args.front_end: - run_gui(args) - else: - run_cli(args) - print(f">> Conversion successful. New model is named {args.merged_model_name}") - except Exception as e: - print(f"** An error occurred while merging the pipelines: {str(e)}") - sys.exit(-1) - except KeyboardInterrupt: - sys.exit(-1) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + try: + if args.front_end: + run_gui(args) + else: + run_cli(args) + print(f'>> Conversion successful.') + except Exception as e: + if str(e).startswith('Not enough space'): + print('** Not enough horizontal space! Try making the window wider, or relaunch with a smaller starting size.') + else: + print(f"** An error occurred while merging the pipelines: {str(e)}") + sys.exit(-1) + except KeyboardInterrupt: + sys.exit(-1) if __name__ == "__main__": main()