Merged
52 changes: 51 additions & 1 deletion helpers/caching/sdxl_embeds.py
@@ -56,13 +56,34 @@ def __init__(
self.data_backend.create_directory(self.cache_dir)
self.batch_write_thread = Thread(target=self.batch_write_embeddings)
self.batch_write_thread.start()
# Ensure the cache has the currently-existing file list.
self.discover_all_files()

def debug_log(self, msg: str):
logger.debug(f"{self.rank_info}{msg}")

def create_hash(self, caption):
return f"{hashlib.md5(caption.encode()).hexdigest()}-{self.model_type}"

def discover_all_files(self, directory: str = None):
"""Identify all files in a directory."""
logger.info(f"(id={self.id}) Listing all text embed cache entries")
# The result isn't returned; we only check whether the list is already stored, and store it if not.
(
StateTracker.get_text_cache_files(data_backend_id=self.id)
or StateTracker.set_text_cache_files(
self.data_backend.list_files(
instance_data_root=self.cache_dir,
str_pattern="*.pt",
),
data_backend_id=self.id,
)
)
self.debug_log(" -> done listing all text embed cache entries")
self.debug_log(
f" -> {StateTracker.get_text_cache_files(data_backend_id=self.id)}"
)

def save_to_cache(self, filename, embeddings):
"""Add write requests to the queue instead of writing directly."""
self.write_queue.put((embeddings, filename))
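A side note on the get-or-set expression in discover_all_files above: it relies on Python's short-circuiting `or`, so the backend listing only runs when the StateTracker has nothing recorded for this backend yet. A minimal standalone illustration of that idiom (hypothetical names, not the project's API):

import_registry = None  # placeholder removed below; see sketch

_registry = {}

def get_entries(backend_id):
    # Returns None (falsy) when nothing has been stored for this backend yet.
    return _registry.get(backend_id)

def set_entries(entries, backend_id):
    _registry[backend_id] = entries
    return entries

def expensive_listing():
    print("listing *.pt files from the data backend...")
    return {"abc123-sdxl.pt": False}

# The listing only runs on the first call; the second call short-circuits on
# the truthy cached dict. Note: an empty dict is also falsy, so a backend with
# zero cached files would re-list each time in this sketch.
get_entries("backend-1") or set_entries(expensive_listing(), "backend-1")
get_entries("backend-1") or set_entries(expensive_listing(), "backend-1")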
@@ -221,10 +242,39 @@ def encode_prompt(self, prompt: str, is_validation: bool = False):

def compute_embeddings_for_prompts(
self,
prompts,
all_prompts,
return_concat: bool = True,
is_validation: bool = False,
):
existing_cache_filenames = list(
StateTracker.get_text_cache_files(data_backend_id=self.id).keys()
)
all_cache_filenames = [f"{self.create_hash(p)}.pt" for p in all_prompts]
self.debug_log(f"Existing cache filenames: {existing_cache_filenames}")
self.debug_log(f"All cache filenames: {all_cache_filenames}")
# Check if we have all the files in the cache
if (
not is_validation
and not return_concat
and all([f in existing_cache_filenames for f in all_cache_filenames])
):
logger.debug(f"(id={self.id}) All prompts are cached, ignoring.")
return None
# Reduce prompts down to the list of uncached prompts.
if not return_concat and not is_validation:
prompts = [
p
for p in all_prompts
if f"{self.create_hash(p)}.pt" not in existing_cache_filenames
]
self.debug_log(
f"Reduced count of prompts for processing from {len(all_prompts)} to {len(prompts)}"
)
else:
prompts = all_prompts
logger.info(
f"Beginning caching of text embeds, we have {len(prompts)} prompts to process."
)
if self.model_type == "sdxl":
return self.compute_embeddings_for_sdxl_prompts(
prompts,
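To make the new early-exit behaviour in compute_embeddings_for_prompts concrete: the change boils down to md5-keyed cache filenames plus a set-membership filter over what the StateTracker already knows about. A rough sketch with hypothetical helper names (not the actual class methods):

import hashlib

MODEL_TYPE = "sdxl"  # assumption: mirrors the cache's model_type suffix

def cache_filename(prompt: str) -> str:
    # Same scheme as create_hash(): md5 of the caption plus the model type.
    return f"{hashlib.md5(prompt.encode()).hexdigest()}-{MODEL_TYPE}.pt"

def uncached_prompts(all_prompts, existing_cache_filenames):
    # Keep only prompts whose cache file is not already present.
    existing = set(existing_cache_filenames)
    return [p for p in all_prompts if cache_filename(p) not in existing]

all_prompts = ["a cat", "a dog", "a bird"]
already_cached = {cache_filename("a dog")}
print(uncached_prompts(all_prompts, already_cached))  # -> ['a cat', 'a bird']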
7 changes: 4 additions & 3 deletions helpers/data_backend/factory.py
@@ -367,10 +367,11 @@ def configure_multi_databackend(
use_captions=use_captions,
)
if "text" not in args.skip_file_discovery:
logger.info(
f"Pre-computing text embeds / updating cache. We have {len(captions)} captions to process."
logger.debug(
f"Pre-computing text embeds / updating cache. We have {len(captions)} captions to process, though these will be filtered next."
)
init_backend["text_embed_cache"].compute_embeddings_for_sdxl_prompts(
logger.info("Initialise text embed pre-computation.")
init_backend["text_embed_cache"].compute_embeddings_for_prompts(
captions, return_concat=False
)
accelerator.wait_for_everyone()
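For reviewers less familiar with the calling convention: return_concat=False is the cache-warming mode the factory uses here at start-up, while validation and training paths ask for the embeddings back. A toy sketch of that contract (a hypothetical stand-in class, not the real cache):

class StubEmbedCache:
    """Hypothetical stand-in for the text embed cache, for illustration only."""

    def __init__(self):
        self.cached = set()

    def compute_embeddings_for_prompts(self, prompts, return_concat=True, is_validation=False):
        uncached = [p for p in prompts if p not in self.cached]
        self.cached.update(uncached)
        if not return_concat and not is_validation:
            # Cache-warming mode, as used by the factory: nothing is returned.
            return None
        return [f"embeds({p})" for p in prompts]

cache = StubEmbedCache()
# Factory start-up: warm the cache without materialising embeddings.
cache.compute_embeddings_for_prompts(["a cat", "a dog"], return_concat=False)
# Validation time: the embeddings themselves are needed.
print(cache.compute_embeddings_for_prompts(["a cat"], is_validation=True))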
40 changes: 39 additions & 1 deletion helpers/training/state_tracker.py
@@ -19,6 +19,7 @@ class StateTracker:
## Caches
all_image_files = {}
all_vae_cache_files = {}
all_text_cache_files = {}
all_caption_files = None

## Backend entities for retrieval
@@ -36,7 +37,11 @@ class StateTracker:

@classmethod
def delete_cache_files(cls):
for cache_name in ["all_image_files", "all_vae_cache_files"]:
for cache_name in [
"all_image_files",
"all_vae_cache_files",
"all_text_cache_files",
]:
cache_path = Path(cls.args.output_dir) / f"{cache_name}.json"
if cache_path.exists():
try:
@@ -60,6 +65,13 @@ def delete_cache_files(cls):
except:
pass

filelist = Path(cls.args.output_dir).glob("all_text_cache_files_*.json")
for file in filelist:
try:
file.unlink()
except:
pass

@classmethod
def _load_from_disk(cls, cache_name):
cache_path = Path(cls.args.output_dir) / f"{cache_name}.json"
@@ -248,6 +260,32 @@ def get_vae_cache_files(cls: list, data_backend_id: str):
)
return cls.all_vae_cache_files[data_backend_id]

@classmethod
def set_text_cache_files(cls, raw_file_list: list, data_backend_id: str):
if cls.all_text_cache_files.get(data_backend_id) is not None:
cls.all_text_cache_files[data_backend_id].clear()
else:
cls.all_text_cache_files[data_backend_id] = {}
for subdirectory_list in raw_file_list:
_, _, files = subdirectory_list
for filename in files:
cls.all_text_cache_files[data_backend_id][path.basename(filename)] = False
cls._save_to_disk(
"all_text_cache_files_{}".format(data_backend_id),
cls.all_text_cache_files[data_backend_id],
)
logger.debug(
f"set_text_cache_files found {len(cls.all_text_cache_files[data_backend_id])} images."
)

@classmethod
def get_text_cache_files(cls, data_backend_id: str):
if data_backend_id not in cls.all_text_cache_files:
cls.all_text_cache_files[data_backend_id] = cls._load_from_disk(
"all_text_cache_files_{}".format(data_backend_id)
)
return cls.all_text_cache_files[data_backend_id]

@classmethod
def set_caption_files(cls, caption_files):
cls.all_caption_files = caption_files
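To summarise the StateTracker additions: set_text_cache_files flattens an os.walk-style listing into a {basename: False} dict and persists it to a per-backend JSON file, while get_text_cache_files lazily reloads it when the backend id is not yet in memory. A rough standalone sketch under those assumptions (hypothetical directory, simplified error handling):

import json
from os import path
from pathlib import Path

OUTPUT_DIR = Path("/tmp/state_tracker_demo")  # hypothetical output_dir
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

_text_cache_files = {}

def set_text_cache_files(raw_file_list, data_backend_id):
    # raw_file_list mimics an os.walk()-style listing: (root, dirs, files) tuples.
    entries = {}
    for _root, _dirs, files in raw_file_list:
        for filename in files:
            entries[path.basename(filename)] = False
    _text_cache_files[data_backend_id] = entries
    cache_path = OUTPUT_DIR / f"all_text_cache_files_{data_backend_id}.json"
    cache_path.write_text(json.dumps(entries))

def get_text_cache_files(data_backend_id):
    # Lazily reload from disk when the id has not been seen in this process.
    if data_backend_id not in _text_cache_files:
        cache_path = OUTPUT_DIR / f"all_text_cache_files_{data_backend_id}.json"
        _text_cache_files[data_backend_id] = (
            json.loads(cache_path.read_text()) if cache_path.exists() else {}
        )
    return _text_cache_files[data_backend_id]

set_text_cache_files([("cache", [], ["abc-sdxl.pt", "def-sdxl.pt"])], "text-embeds")
print(len(get_text_cache_files("text-embeds")))  # -> 2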