Add --log_tokenization to sysargs #2523

Merged: 8 commits, Feb 5, 2023
invokeai/backend/invoke_ai_web_server.py: 13 changes (5 additions, 8 deletions)
@@ -626,9 +626,10 @@ def handle_generate_image_event(
                 printable_parameters["init_mask"][:64] + "..."
             )

-        print(
-            f">> Image generation requested: {printable_parameters}\nESRGAN parameters: {esrgan_parameters}\nFacetool parameters: {facetool_parameters}"
-        )
+        print(f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
+        print(f'>> ESRGAN Parameters: {esrgan_parameters}')
+        print(f'>> Facetool Parameters: {facetool_parameters}')
+
         self.generate_images(
             generation_parameters,
             esrgan_parameters,
@@ -1154,7 +1155,7 @@ def image_done(image, seed, first_seed, attention_maps_image=None):
                     image, os.path.basename(path), self.thumbnail_image_path
                 )

-                print(f'>> Image generated: "{path}"')
+                print(f'\n\n>> Image generated: "{path}"\n')
                 self.write_log_message(f'[Generated] "{path}": {command}')

                 if progress.total_iterations > progress.current_iteration:
@@ -1193,8 +1194,6 @@ def image_done(image, seed, first_seed, attention_maps_image=None):

                 progress.set_current_iteration(progress.current_iteration + 1)

-                print(generation_parameters)
-
         def diffusers_step_callback_adapter(*cb_args, **kwargs):
             if isinstance(cb_args[0], PipelineIntermediateState):
                 progress_state: PipelineIntermediateState = cb_args[0]
@@ -1305,8 +1304,6 @@ def parameters_to_generated_image_metadata(self, parameters):

        rfc_dict["variations"] = variations

-        print(parameters)
-
        if rfc_dict["type"] == "img2img":
            rfc_dict["strength"] = parameters["strength"]
            rfc_dict["fit"] = parameters["fit"]  # TODO: Noncompliant
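A note on the context above: the handler logs a sanitized copy of the request in which bulky base64 fields such as init_mask are cut to 64 characters before printing. A minimal, self-contained sketch of that truncate-before-logging pattern (hypothetical field names, not the web server's exact code):

import copy

def printable(parameters: dict, max_len: int = 64) -> dict:
    # Deep-copy so truncation never mutates the real generation request.
    printable_parameters = copy.deepcopy(parameters)
    for key in ('init_img', 'init_mask'):  # hypothetical bulky fields
        value = printable_parameters.get(key)
        if isinstance(value, str) and len(value) > max_len:
            printable_parameters[key] = value[:max_len] + '...'
    return printable_parameters

params = {'prompt': 'a fantastic alien landscape', 'init_mask': 'iVBOR' * 100}
print(f'\n>> Image Generation Parameters:\n\n{printable(params)}\n')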
ldm/generate.py: 2 changes (1 addition, 1 deletion)
@@ -574,7 +574,7 @@ def process_image(image,seed):
                 print('>> Could not generate image.')

         toc = time.time()
-        print('>> Usage stats:')
+        print('\n>> Usage stats:')
         print(
             f'>> {len(results)} image(s) generated in', '%4.2fs' % (
                 toc - tic)
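The surrounding code uses a plain tic/toc measurement: capture time.time() before the generation loop and report the elapsed seconds with a '%4.2fs' format. A minimal sketch of the same pattern, with a sleep standing in for image generation:

import time

tic = time.time()
results = []
for seed in (42, 43):      # stand-in for the generation loop
    time.sleep(0.1)        # pretend to render an image
    results.append(f'outputs/{seed:06}.png')
toc = time.time()

print('\n>> Usage stats:')
print(f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic))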
ldm/invoke/args.py: 9 changes (8 additions, 1 deletion)
@@ -196,6 +196,7 @@ def parse_args(self):
         elif os.path.exists(legacyinit):
             print(f'>> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init.')
             sysargs.insert(0,f'@{legacyinit}')
+        Globals.log_tokenization = self._arg_parser.parse_args(sysargs).log_tokenization

         self._arg_switches = self._arg_parser.parse_args(sysargs)
         return self._arg_switches
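Prompt conditioning runs far from the CLI layer, so the new switch is read out of sysargs immediately and cached on the Globals namespace; modules such as ldm/invoke/conditioning.py can then consult it without the parsed arguments being threaded through every call. A minimal sketch of the pattern, using types.SimpleNamespace as a stand-in for ldm.invoke.globals.Globals:

import argparse
from types import SimpleNamespace

# Stand-in for ldm.invoke.globals.Globals, a process-wide namespace.
Globals = SimpleNamespace(log_tokenization=False)

parser = argparse.ArgumentParser()
parser.add_argument('--log_tokenization', '-t', action='store_true',
                    help='shows how the prompt is split into tokens')

def parse_args(sysargs):
    # Cache the flag so any module can read it later in the session.
    Globals.log_tokenization = parser.parse_args(sysargs).log_tokenization
    return parser.parse_args(sysargs)

def tokenization_logging_enabled(log_tokens=False):
    # Mirrors the gate added in conditioning.py: an explicit per-call
    # argument OR the session-wide CLI switch turns logging on.
    return log_tokens or Globals.log_tokenization

parse_args(['-t'])
print(tokenization_logging_enabled())  # True for the whole session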
@@ -599,6 +600,12 @@ def _create_arg_parser(self):
             help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
             default='k_lms',
         )
+        render_group.add_argument(
+            '--log_tokenization',
+            '-t',
+            action='store_true',
+            help='shows how the prompt is split into tokens'
+        )
         render_group.add_argument(
             '-f',
             '--strength',
@@ -744,7 +751,7 @@ def _create_dream_cmd_parser(self):
             invoke> !fetch 0000015.8929913.png
             invoke> a fantastic alien landscape -W 576 -H 512 -s 60 -A plms -C 7.5
             invoke> !fetch /path/to/images/*.png prompts.txt

             !replay /path/to/prompts.txt
             Replays all the prompts contained in the file prompts.txt.

ldm/invoke/conditioning.py: 22 changes (13 additions, 9 deletions)
@@ -17,6 +17,7 @@
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
 from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
+from ldm.invoke.globals import Globals


 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
@@ -92,9 +93,9 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
     Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
     """

-    if log_tokens:
-        print(f">> Parsed prompt to {parsed_prompt}")
-        print(f">> Parsed negative prompt to {parsed_negative_prompt}")
+    if log_tokens or Globals.log_tokenization:
+        print(f"\n>> [TOKENLOG] Parsed Prompt: {parsed_prompt}")
+        print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {parsed_negative_prompt}")

     conditioning = None
     cac_args: cross_attention_control.Arguments = None
@@ -235,7 +236,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
     fragments = [x.text for x in flattened_prompt.children]
     weights = [x.weight for x in flattened_prompt.children]
     embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
-    if log_tokens:
+    if log_tokens or Globals.log_tokenization:
         text = " ".join(fragments)
         log_tokenization(text, model, display_label=log_display_label)

@@ -273,12 +274,12 @@ def log_tokenization(text, model, display_label=None):
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     """
-
     tokens = model.cond_stage_model.tokenizer.tokenize(text)
     tokenized = ""
     discarded = ""
     usedTokens = 0
     totalTokens = len(tokens)
+
     for i in range(0, totalTokens):
         token = tokens[i].replace('</w>', ' ')
         # alternate color
@@ -288,8 +289,11 @@ def log_tokenization(text, model, display_label=None):
             usedTokens += 1
         else:  # over max token length
             discarded = discarded + f"\x1b[0;3{s};40m{token}"
-    print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
+
+    if usedTokens > 0:
+        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
+        print(f'{tokenized}\x1b[0m')
+
     if discarded != "":
-        print(
-            f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
-        )
+        print(f'\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):')
+        print(f'{discarded}\x1b[0m')
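The function above paints each token with a rotating ANSI color code (the \x1b[0;3{s};40m escapes) so adjacent tokens stay visually distinct, and collects anything past the model's token limit into a separate Discarded report. A self-contained sketch of the same technique, assuming a naive whitespace tokenizer and a hypothetical 10-token limit in place of CLIP's 77:

MAX_TOKENS = 10  # hypothetical cap; CLIP's context length is 77

def log_tokenization_sketch(text: str, display_label: str = '') -> None:
    tokens = text.split()   # stand-in for the CLIP BPE tokenizer
    tokenized, discarded, used_tokens = '', '', 0
    for i, token in enumerate(tokens):
        s = i % 6 + 1       # rotate through six ANSI colors
        if i < MAX_TOKENS:
            tokenized += f"\x1b[0;3{s};40m{token} "
            used_tokens += 1
        else:               # over max token length
            discarded += f"\x1b[0;3{s};40m{token} "
    if used_tokens > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label} ({used_tokens}):')
        print(f'{tokenized}\x1b[0m')
    if discarded != '':
        print(f'\n>> [TOKENLOG] Tokens Discarded ({len(tokens) - used_tokens}):')
        print(f'{discarded}\x1b[0m')

log_tokenization_sketch(
    'a fantastic alien landscape with towering crystal spires under twin suns')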
ldm/invoke/generator/diffusers_pipeline.py: 2 changes (0 additions, 2 deletions)
@@ -4,7 +4,6 @@
 import inspect
 import secrets
 import sys
-import warnings
 from dataclasses import dataclass, field
 from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any

@@ -641,7 +640,6 @@ def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fr

     @property
     def cond_stage_model(self):
-        warnings.warn("legacy compatibility layer", DeprecationWarning)
         return self.prompt_fragments_to_embeddings_converter

     @torch.inference_mode()
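The deleted warnings.warn was a DeprecationWarning inside a compatibility property, and cond_stage_model is touched on every conditioning call (the token logger above reaches it via model.cond_stage_model.tokenizer), which would make the warning fire repeatedly; presumably that is why it is dropped here. A sketch of the shim as it stood before this deletion, with a hypothetical class name:

import warnings

class Pipeline:
    """Hypothetical stand-in for the diffusers pipeline class."""

    def __init__(self):
        self.prompt_fragments_to_embeddings_converter = object()

    @property
    def cond_stage_model(self):
        # Pre-PR behavior, removed by this diff: warn on every access
        # to the legacy attribute name.
        warnings.warn("legacy compatibility layer", DeprecationWarning)
        return self.prompt_fragments_to_embeddings_converter

pipe = Pipeline()
_ = pipe.cond_stage_model  # each access emitted a DeprecationWarning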