Skip to content

Commit

Permalink
Cleanup .update deprecations warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Mar 15, 2024
1 parent 7081b22 commit 90a0298
Show file tree
Hide file tree
Showing 23 changed files with 121 additions and 86 deletions.
2 changes: 1 addition & 1 deletion kohya_gui/basic_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def list_images_dirs(path):

# Event handler for dynamic update of dropdown choices
images_dir.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_images_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_images_dirs(path)),
inputs=images_dir,
outputs=images_dir,
show_progress=False,
Expand Down
2 changes: 1 addition & 1 deletion kohya_gui/blip_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def list_train_dirs(path):
)

train_data_dir.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_train_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)),
inputs=train_data_dir,
outputs=train_data_dir,
show_progress=False,
Expand Down
4 changes: 2 additions & 2 deletions kohya_gui/class_configuration_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, headless=False, output_dir: gr.Dropdown = None):

def update_configs(output_dir):
self.output_dir = output_dir
return gr.Dropdown().update(choices=[""] + list(list_files(output_dir, exts=[".json"], all=True)))
return gr.Dropdown(choices=[""] + list(list_files(output_dir, exts=[".json"], all=True)))

def list_configs(path):
self.output_dir = path
Expand Down Expand Up @@ -47,7 +47,7 @@ def list_configs(path):
)

self.config_file_name.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_configs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_configs(path)),
inputs=self.config_file_name,
outputs=self.config_file_name,
show_progress=False,
Expand Down
6 changes: 3 additions & 3 deletions kohya_gui/class_folders.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,19 @@ def list_logging_dirs(path):
)

self.output_dir.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_output_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_output_dirs(path)),
inputs=self.output_dir,
outputs=self.output_dir,
show_progress=False,
)
self.reg_data_dir.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_data_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_data_dirs(path)),
inputs=self.reg_data_dir,
outputs=self.reg_data_dir,
show_progress=False,
)
self.logging_dir.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_logging_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_logging_dirs(path)),
inputs=self.logging_dir,
outputs=self.logging_dir,
show_progress=False,
Expand Down
2 changes: 1 addition & 1 deletion kohya_gui/class_source_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def list_train_dirs(path):
)

self.train_data_dir.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_train_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)),
inputs=self.train_data_dir,
outputs=self.train_data_dir,
show_progress=False,
Expand Down
109 changes: 72 additions & 37 deletions kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,74 +55,109 @@
ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"]


def check_if_model_exist(output_name, output_dir, save_model_as, headless=False):
def check_if_model_exist(output_name: str, output_dir: str, save_model_as: str, headless: bool = False) -> bool:
'''
Checks if a model with the same name already exists and prompts the user to overwrite it if it does.
Parameters:
output_name (str): The name of the output model.
output_dir (str): The directory where the model is saved.
save_model_as (str): The format to save the model as.
headless (bool, optional): If True, skips the verification and returns False. Defaults to False.
Returns:
bool: True if the model already exists and the user chooses not to overwrite it, otherwise False.
'''
if headless:
log.info(
"Headless mode, skipping verification if model already exist... if model already exist it will be overwritten..."
'Headless mode, skipping verification if model already exist... if model already exist it will be overwritten...'
)
return False

if save_model_as in ["diffusers", "diffusers_safetendors"]:
if save_model_as in ['diffusers', 'diffusers_safetendors']:
ckpt_folder = os.path.join(output_dir, output_name)
if os.path.isdir(ckpt_folder):
msg = f"A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?"
if not easygui.ynbox(msg, "Overwrite Existing Model?"):
log.info("Aborting training due to existing model with same name...")
msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
log.info('Aborting training due to existing model with same name...')
return True
elif save_model_as in ["ckpt", "safetensors"]:
ckpt_file = os.path.join(output_dir, output_name + "." + save_model_as)
elif save_model_as in ['ckpt', 'safetensors']:
ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
if os.path.isfile(ckpt_file):
msg = f"A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?"
if not easygui.ynbox(msg, "Overwrite Existing Model?"):
log.info("Aborting training due to existing model with same name...")
msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
log.info('Aborting training due to existing model with same name...')
return True
else:
log.info(
'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...'
'Can\'t verify if existing model exist when save model is set as "same as source model", continuing to train model...'
)
return False

return False


def output_message(msg="", title="", headless=False):
def output_message(msg: str = "", title: str = "", headless: bool = False) -> None:
"""
Outputs a message to the user, either in a message box or in the log.
Parameters:
msg (str, optional): The message to be displayed. Defaults to an empty string.
title (str, optional): The title of the message box. Defaults to an empty string.
headless (bool, optional): If True, the message is logged instead of displayed in a message box. Defaults to False.
Returns:
None
"""
if headless:
log.info(msg)
else:
msgbox(msg=msg, title=title)


def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
# Converts refresh_component into a list for uniform processing. If it's already a list, keep it the same.
refresh_components = (
refresh_component
if isinstance(refresh_component, list)
else [refresh_component]
)

# Initialize label to None. This will store the label of the first component with a non-None label, if any.
label = None
# Iterate over each component to find the first non-None label and assign it to 'label'.
for comp in refresh_components:
label = getattr(comp, "label", None)
if label is not None:
break

# Define the refresh function that will be triggered upon clicking the refresh button.
def refresh():
# Invoke the refresh_method, which is intended to perform the refresh operation.
refresh_method()
# Determine the arguments for the refresh: call refreshed_args if it's callable, otherwise use it directly.
args = refreshed_args() if callable(refreshed_args) else refreshed_args

# For each key-value pair in args, update the corresponding properties of each component.
for k, v in args.items():
for comp in refresh_components:
setattr(comp, k, v)

# Use gr.update to refresh the UI components. If multiple components are present, update each; else, update only the first.
return (
[gr.update(**(args or {})) for _ in refresh_components]
[gr.Dropdown(**(args or {})) for _ in refresh_components]
if len(refresh_components) > 1
else gr.update(**(args or {}))
else gr.Dropdown(**(args or {}))
)

# Create a refresh button with the specified label (via refresh_symbol), ID, and classes.
# 'refresh_symbol' should be defined outside this function or passed as an argument, representing the button's label or icon.
refresh_button = gr.Button(
value=refresh_symbol, elem_id=elem_id, elem_classes=["tool"]
)
# Configure the button to invoke the refresh function.
refresh_button.click(fn=refresh, inputs=[], outputs=refresh_components)
# Return the configured refresh button to be used in the UI.
return refresh_button


Expand Down Expand Up @@ -525,9 +560,9 @@ def color_aug_changed(color_aug):
msgbox(
'Disabling "Cache latent" because "Color augmentation" has been selected...'
)
return gr.Checkbox.update(value=False, interactive=False)
return gr.Checkbox(value=False, interactive=False)
else:
return gr.Checkbox.update(value=True, interactive=True)
return gr.Checkbox(value=True, interactive=True)


def save_inference_file(output_dir, v2, v_parameterization, output_name):
Expand Down Expand Up @@ -568,11 +603,11 @@ def set_pretrained_model_name_or_path_input(
# Check if the given pretrained_model_name_or_path is in the list of SDXL models
if pretrained_model_name_or_path in SDXL_MODELS:
log.info("SDXL model selected. Setting sdxl parameters")
v2 = gr.Checkbox.update(value=False, visible=False)
v_parameterization = gr.Checkbox.update(value=False, visible=False)
sdxl = gr.Checkbox.update(value=True, visible=False)
v2 = gr.Checkbox(value=False, visible=False)
v_parameterization = gr.Checkbox(value=False, visible=False)
sdxl = gr.Checkbox(value=True, visible=False)
return (
gr.Dropdown().update(),
gr.Dropdown(),
v2,
v_parameterization,
sdxl,
Expand All @@ -581,11 +616,11 @@ def set_pretrained_model_name_or_path_input(
# Check if the given pretrained_model_name_or_path is in the list of V2 base models
if pretrained_model_name_or_path in V2_BASE_MODELS:
log.info("SD v2 base model selected. Setting --v2 parameter")
v2 = gr.Checkbox.update(value=True, visible=False)
v_parameterization = gr.Checkbox.update(value=False, visible=False)
sdxl = gr.Checkbox.update(value=False, visible=False)
v2 = gr.Checkbox(value=True, visible=False)
v_parameterization = gr.Checkbox(value=False, visible=False)
sdxl = gr.Checkbox(value=False, visible=False)
return (
gr.Dropdown().update(),
gr.Dropdown(),
v2,
v_parameterization,
sdxl,
Expand All @@ -596,11 +631,11 @@ def set_pretrained_model_name_or_path_input(
log.info(
"SD v2 model selected. Setting --v2 and --v_parameterization parameters"
)
v2 = gr.Checkbox.update(value=True, visible=False)
v_parameterization = gr.Checkbox.update(value=True, visible=False)
sdxl = gr.Checkbox.update(value=False, visible=False)
v2 = gr.Checkbox(value=True, visible=False)
v_parameterization = gr.Checkbox(value=True, visible=False)
sdxl = gr.Checkbox(value=False, visible=False)
return (
gr.Dropdown().update(),
gr.Dropdown(),
v2,
v_parameterization,
sdxl,
Expand All @@ -609,20 +644,20 @@ def set_pretrained_model_name_or_path_input(
# Check if the given pretrained_model_name_or_path is in the list of V1 models
if pretrained_model_name_or_path in V1_MODELS:
log.info(f"{pretrained_model_name_or_path} model selected.")
v2 = gr.Checkbox.update(value=False, visible=False)
v_parameterization = gr.Checkbox.update(value=False, visible=False)
sdxl = gr.Checkbox.update(value=False, visible=False)
v2 = gr.Checkbox(value=False, visible=False)
v_parameterization = gr.Checkbox(value=False, visible=False)
sdxl = gr.Checkbox(value=False, visible=False)
return (
gr.Dropdown().update(),
gr.Dropdown(),
v2,
v_parameterization,
sdxl,
)

# Check if the model_list is set to 'custom'
v2 = gr.Checkbox.update(visible=True)
v_parameterization = gr.Checkbox.update(visible=True)
sdxl = gr.Checkbox.update(visible=True)
v2 = gr.Checkbox(visible=True)
v_parameterization = gr.Checkbox(visible=True)
sdxl = gr.Checkbox(visible=True)

if refresh_method is not None:
args = dict(
Expand All @@ -631,7 +666,7 @@ def set_pretrained_model_name_or_path_input(
else:
args = {}
return (
gr.Dropdown().update(**args),
gr.Dropdown(**args),
v2,
v_parameterization,
sdxl,
Expand Down
4 changes: 2 additions & 2 deletions kohya_gui/convert_lcm_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ def list_save_to(path):
show_progress=False,
)
model_path.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_models(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
inputs=model_path,
outputs=model_path,
show_progress=False,
)
name.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_save_to(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_save_to(path)),
inputs=name,
outputs=name,
show_progress=False,
Expand Down
4 changes: 2 additions & 2 deletions kohya_gui/convert_model_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def list_target_folder(path):
)

source_model_input.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_source_model(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_source_model(path)),
inputs=source_model_input,
outputs=source_model_input,
show_progress=False,
Expand Down Expand Up @@ -273,7 +273,7 @@ def list_target_folder(path):
)

target_model_folder_input.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_target_folder(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_target_folder(path)),
inputs=target_model_folder_input,
outputs=target_model_folder_input,
show_progress=False,
Expand Down
2 changes: 1 addition & 1 deletion kohya_gui/dataset_balancing_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def list_dataset_dirs(path):
label='Training steps per concept per epoch',
)
select_dataset_folder_input.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_dataset_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_dataset_dirs(path)),
inputs=select_dataset_folder_input,
outputs=select_dataset_folder_input,
show_progress=False,
Expand Down
6 changes: 3 additions & 3 deletions kohya_gui/dreambooth_folder_creation_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def list_train_data_dirs(path):
elem_id='number_input',
)
util_training_images_dir_input.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_train_data_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)),
inputs=util_training_images_dir_input,
outputs=util_training_images_dir_input,
show_progress=False,
Expand Down Expand Up @@ -210,7 +210,7 @@ def list_reg_data_dirs(path):
elem_id='number_input',
)
util_regularization_images_dir_input.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_reg_data_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_reg_data_dirs(path)),
inputs=util_regularization_images_dir_input,
outputs=util_regularization_images_dir_input,
show_progress=False,
Expand All @@ -236,7 +236,7 @@ def list_train_output_dirs(path):
get_folder_path, outputs=util_training_dir_output
)
util_training_dir_output.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_train_output_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_train_output_dirs(path)),
inputs=util_training_dir_output,
outputs=util_training_dir_output,
show_progress=False,
Expand Down
4 changes: 2 additions & 2 deletions kohya_gui/extract_lora_from_dylora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ def list_save_to(path):
)

model.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_models(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
inputs=model,
outputs=model,
show_progress=False,
)
save_to.change(
fn=lambda path: gr.Dropdown().update(choices=[""] + list_save_to(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_save_to(path)),
inputs=save_to,
outputs=save_to,
show_progress=False,
Expand Down

0 comments on commit 90a0298

Please sign in to comment.