Skip to content

Commit

Permalink
Add customizable LoRA folder
Browse files Browse the repository at this point in the history
  • Loading branch information
JustMaier committed Mar 3, 2023
1 parent 811eb04 commit 579d694
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
23 changes: 19 additions & 4 deletions civitai/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ def get_tags(query, page=1, page_size=20):
#endregion API

#region Get Utils
def get_lora_dir():
lora_dir = shared.opts.data.get('civitai_folder_lora', shared.cmd_opts.lora_dir).strip()
if not lora_dir: lora_dir = shared.cmd_opts.lora_dir
return lora_dir

def get_automatic_type(type: str):
if type == 'Hypernetwork': return 'hypernet'
return type.lower()
Expand Down Expand Up @@ -178,18 +183,18 @@ def get_resources_in_folder(type, folder, exts=[], exts_exclude=[]):
return resources

resources = []
def load_resource_list(types=['LORA', 'Hypernetwork', 'TextualInversion', 'Checkpoint', 'VAE']):
def load_resource_list(types=['LORA', 'Hypernetwork', 'TextualInversion', 'Checkpoint', 'VAE', 'Controlnet']):
global resources

# If resources is empty and types is empty, load all types
# This is a helper to be able to get the resource list without
# having to worry about initialization. On subsequent calls, no work will be done
if len(resources) == 0 and len(types) == 0:
types = ['LORA', 'Hypernetwork', 'TextualInversion', 'Checkpoint', 'VAE']
types = ['LORA', 'Hypernetwork', 'TextualInversion', 'Checkpoint', 'VAE', 'Controlnet']

if 'LORA' in types:
resources = [r for r in resources if r['type'] != 'LORA']
resources += get_resources_in_folder('LORA', shared.cmd_opts.lora_dir, ['pt', 'safetensors', 'ckpt'])
resources += get_resources_in_folder('LORA', get_lora_dir(), ['pt', 'safetensors', 'ckpt'])
if 'Hypernetwork' in types:
resources = [r for r in resources if r['type'] != 'Hypernetwork']
resources += get_resources_in_folder('Hypernetwork', shared.cmd_opts.hypernetwork_dir, ['pt', 'safetensors', 'ckpt'])
Expand All @@ -199,6 +204,9 @@ def load_resource_list(types=['LORA', 'Hypernetwork', 'TextualInversion', 'Check
if 'Checkpoint' in types:
resources = [r for r in resources if r['type'] != 'Checkpoint']
resources += get_resources_in_folder('Checkpoint', sd_models.model_path, ['safetensors', 'ckpt'], ['vae.safetensors', 'vae.ckpt'])
if 'Controlnet' in types:
resources = [r for r in resources if r['type'] != 'Controlnet']
resources += get_resources_in_folder('Controlnet', os.path.join(models_path, "ControlNet"), ['safetensors', 'ckpt'], ['vae.safetensors', 'vae.ckpt'])
if 'VAE' in types:
resources = [r for r in resources if r['type'] != 'VAE']
resources += get_resources_in_folder('VAE', sd_models.model_path, ['vae.pt', 'vae.safetensors', 'vae.ckpt'])
Expand Down Expand Up @@ -265,6 +273,7 @@ def load_resource(resource: ResourceRequest, on_progress=None):

if resource['type'] == 'Checkpoint': load_model(resource, on_progress)
elif resource['type'] == 'CheckpointConfig': load_model_config(resource, on_progress)
elif resource['type'] == 'Controlnet': load_controlnet(resource, on_progress)
elif resource['type'] == 'Hypernetwork': load_hypernetwork(resource, on_progress)
elif resource['type'] == 'TextualInversion': load_textual_inversion(resource, on_progress)
elif resource['type'] == 'LORA': load_lora(resource, on_progress)
Expand Down Expand Up @@ -301,11 +310,17 @@ def load_model(resource: ResourceRequest, on_progress=None):

return model

def load_controlnet(resource: ResourceRequest, on_progress=None):
isAvailable = load_if_missing(os.path.join(models_path, 'ControlNet', resource['name']), resource['url'], on_progress)
# TODO: reload controlnet list - not sure best way to import this
# if isAvailable is None:
# controlnet.list_available_models()

def load_textual_inversion(resource: ResourceRequest, on_progress=None):
load_if_missing(os.path.join(shared.cmd_opts.embeddings_dir, resource['name']), resource['url'], on_progress)

def load_lora(resource: ResourceRequest, on_progress=None):
isAvailable = load_if_missing(os.path.join(shared.cmd_opts.lora_dir, resource['name']), resource['url'], on_progress)
isAvailable = load_if_missing(os.path.join(get_lora_dir(), resource['name']), resource['url'], on_progress)
# TODO: reload lora list - not sure best way to import this
# if isAvailable is None:
# lora.list_available_loras()
Expand Down
14 changes: 14 additions & 0 deletions scripts/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,14 @@ def on_room_presence(payload: RoomPresence):
if connected: log("Connected to Civitai Instance")
else: log("Disconnected from Civitai Instance")

upgraded_key = None
@sio.on('upgradeKey')
def on_upgrade_key(payload: UpgradeKeyPayload):
global upgraded_key

log("Link Key upgraded")
shared.opts.data['civitai_link_key'] = payload['key']
upgraded_key = payload['key']

@sio.on('error')
def on_error(payload: ErrorPayload):
Expand Down Expand Up @@ -208,9 +212,19 @@ def connect_to_civitai(demo: gr.Blocks, app):
socketio_connect()
join_room(key)

old_short_key = None
def on_civitai_link_key_changed():
global old_short_key

if not sio.connected: socketio_connect()

# If the key is upgraded, don't change it back to the short key
if old_short_key is not None and old_short_key == shared.opts.data.get("civitai_link_key", None):
shared.opts.data['civitai_link_key'] = upgraded_key
return

key = shared.opts.data.get("civitai_link_key", None)
if len(key) < 10: old_short_key = key
join_room(key)
#endregion

Expand Down
2 changes: 2 additions & 0 deletions scripts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@ def on_ui_settings():
shared.opts.add_option("civitai_nsfw_previews", shared.OptionInfo(False, "Download NSFW (adult) preview images", section=section))
shared.opts.add_option("civitai_download_missing_models", shared.OptionInfo(True, "Download missing models upon reading generation parameters from prompt", section=section))
shared.opts.add_option("civitai_hashify_resources", shared.OptionInfo(True, "Include resource hashes in image metadata (for resource auto-detection on Civitai)", section=section))
shared.opts.add_option("civitai_folder_lora", shared.OptionInfo("", "LoRA directory (if not default)", section=section))


script_callbacks.on_ui_settings(on_ui_settings)

0 comments on commit 579d694

Please sign in to comment.