Skip to content

Commit

Permalink
Local embeddings support (CLI autocomplete) (#2211)
Browse files Browse the repository at this point in the history
* integrate local embeds with HF embeds

* Update concepts_lib.py

* Update concepts_lib.py

Co-authored-by: BuildTools <unconfigured@null.spigotmc.org>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
  • Loading branch information
3 people committed Jan 4, 2023
1 parent 6c6e534 commit 21bf512
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ignore default image save location and model symbolic link
embeddings/
outputs/
models/ldm/stable-diffusion-v1/model.ckpt
**/restoration/codeformer/weights
Expand Down
61 changes: 51 additions & 10 deletions ldm/invoke/concepts_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, root=None):
'''
self.root = root or Globals.root
self.hf_api = HfApi()
self.local_concepts = dict()
self.concept_list = None
self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase
Expand All @@ -28,17 +29,28 @@ def __init__(self, root=None):

def list_concepts(self)->list:
'''
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
Also adds local concepts in invokeai/embeddings folder.
'''
local_concepts_now = self.get_local_concepts(os.path.join(self.root, 'embeddings'))
local_concepts_to_add = set(local_concepts_now).difference(set(self.local_concepts))
self.local_concepts.update(local_concepts_now)

if self.concept_list is not None:
if local_concepts_to_add:
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
else:
try:
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
self.concept_list = [a.id.split('/')[1] for a in models]
# when init, add all in dir. when not init, add only concepts added between init and now
self.concept_list.extend(list(local_concepts_to_add))
except Exception as e:
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
return self.concept_list
try:
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
self.concept_list = [a.id.split('/')[1] for a in models]
except Exception as e:
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
return self.concept_list

def get_concept_model_path(self, concept_name:str)->str:
'''
Expand All @@ -58,6 +70,12 @@ def concept_to_trigger(self, concept_name:str)->str:
'''
if concept_name in self.triggers:
return self.triggers[concept_name]
elif self.concept_is_local(concept_name):
trigger = f'<{concept_name}>'
self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name
return trigger

file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
if not file:
return None
Expand Down Expand Up @@ -115,10 +133,20 @@ def do_replace(match)->str:
return self.match_concept.sub(do_replace, prompt)

def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
if not self.concept_is_downloaded(concept_name) and not local_only:
if not (self.concept_is_downloaded(concept_name) or self.concept_is_local(concept_name) or local_only):
self.download_concept(concept_name)
path = os.path.join(self._concept_path(concept_name), file_name)

# get local path in invokeai/embeddings if local concept
if self.concept_is_local(concept_name):
concept_path = self._concept_local_path(concept_name)
path = concept_path
else:
concept_path = self._concept_path(concept_name)
path = os.path.join(concept_path, file_name)
return path if os.path.exists(path) else None

def concept_is_local(self, concept_name)->bool:
return concept_name in self.local_concepts

def concept_is_downloaded(self, concept_name)->bool:
concept_directory = self._concept_path(concept_name)
Expand Down Expand Up @@ -167,3 +195,16 @@ def _concept_id(self, concept_name:str)->str:

def _concept_path(self, concept_name:str)->str:
return os.path.join(self.root,'models','sd-concepts-library',concept_name)

def _concept_local_path(self, concept_name:str)->str:
filename = self.local_concepts[concept_name]
return os.path.join(self.root,'embeddings',filename)

def get_local_concepts(self, loc_dir:str):
locs_dic = dict()
if os.path.isdir(loc_dir):
for file in os.listdir(loc_dir):
f = os.path.splitext(file)
if f[1] == '.bin' or f[1] == '.pt':
locs_dic[f[0]] = file
return locs_dic
10 changes: 7 additions & 3 deletions ldm/invoke/readline.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def complete(self, text, state):
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state)

# looking for an embedding concept
elif re.search('<[\w-]*$',buffer):
self.matches= self._concept_completions(text,state)

Expand Down Expand Up @@ -272,12 +273,15 @@ def _seed_completions(self, text, state):
def add_embedding_terms(self, terms:list[str]):
self.embedding_terms = set(terms)
if self.concepts:
self.embedding_terms.update(self.concepts)
self.embedding_terms.update(set(self.concepts.list_concepts()))

def _concept_completions(self, text, state):
if self.concepts is None:
self.concepts = set(Concepts().list_concepts())
self.embedding_terms.update(self.concepts)
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
self.concepts = Concepts()
self.embedding_terms.update(set(self.concepts.list_concepts()))
else:
self.embedding_terms.update(set(self.concepts.list_concepts()))

partial = text[1:] # this removes the leading '<'
if len(partial) == 0:
Expand Down

0 comments on commit 21bf512

Please sign in to comment.