Skip to content

Commit

Permalink
✨Supported custom_embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Jun 12, 2023
1 parent 03fbb00 commit eea5e24
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
29 changes: 29 additions & 0 deletions examples/carefree_creator/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import json

from PIL import Image
from typing import Any
from typing import Dict
Expand All @@ -19,6 +22,8 @@


TDataModel = TypeVar("TDataModel", bound="BaseModel")
text_keys = ["text", "prompt", "negative_prompt"]
CUSTOM_EMBEDDINGS_PATH = os.environ.get("CFDRAW_CFCREATOR_CUSTOM_EMBEDDING_PATH")


def inject(
Expand Down Expand Up @@ -57,6 +62,30 @@ def inject(
data.extraData["lora_paths"] = lora_paths
if lora_scales:
data.extraData["lora_scales"] = lora_scales
# custom embeddings
if CUSTOM_EMBEDDINGS_PATH is not None:
custom_embeddings_folder = Path(CUSTOM_EMBEDDINGS_PATH)
else:
custom_embeddings_folder = Path(__file__).parent / "custom_embeddings"
if custom_embeddings_folder.is_dir():
custom_embedding_paths = [
p
for p in custom_embeddings_folder.iterdir()
if p.is_file() and p.name.endswith(".ce")
]
custom_embeddings = {}
for key in text_keys:
k_text = data.extraData.get(key)
if k_text is None:
continue
for ce_path in custom_embedding_paths:
stem = ce_path.stem
if stem in k_text and stem not in custom_embeddings:
with open(ce_path, "r") as f:
ce = json.load(f)
custom_embeddings[stem] = ce
if custom_embeddings:
data.extraData["custom_embeddings"] = custom_embeddings
# collect
kw = shallow_copy_dict(data.extraData)
if extra is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1. embeddings should be json file
2. file names should be `*.ce`

0 comments on commit eea5e24

Please sign in to comment.