
Commit

move token logging into conditioning.py
lstein committed Sep 4, 2022
1 parent 68e8997 commit 89a7622
Showing 3 changed files with 28 additions and 32 deletions.
ldm/dream/conditioning.py: 28 changes (25 additions & 3 deletions)
@@ -6,9 +6,9 @@
 import torch
 
 class Conditioning():
-    def __init__(self, model, logger=None):
+    def __init__(self, model, log_tokens=False):
         self.model = model
-        self.logger = logger if logger else lambda : None # right way to make a noop?
+        self.logger = self.log_tokenization if log_tokens else lambda a : None # right way to make a noop?
 
     def get_uc_and_c(self, prompt, skip_normalize=False):
         uc = self.model.get_learned_conditioning([''])
@@ -66,6 +66,28 @@ def split_weighted_subprompts(self, text, skip_normalize=False)->list:
             equal_weight = 1 / len(parsed_prompts)
             return [(x[0], equal_weight) for x in parsed_prompts]
         return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
 
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
+    # but for readability it has been replaced with ' '
+    def log_tokenization(self, text):
+        tokens = self.model.cond_stage_model.tokenizer._tokenize(text)
+        tokenized = ""
+        discarded = ""
+        usedTokens = 0
+        totalTokens = len(tokens)
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
+            # alternate color
+            s = (usedTokens % 6) + 1
+            if i < self.model.cond_stage_model.max_length:
+                tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
+                usedTokens += 1
+            else: # over max token length
+                discarded = discarded + f"\x1b[0;3{s};40m{token}"
+        print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
+        if discarded != "":
+            print(
+                f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")


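A standalone sketch of the ANSI color-cycling trick used by log_tokenization above; the token list here is a hypothetical stand-in for what the model's CLIP tokenizer returns:

    # '\x1b[0;3{s};40m' selects foreground colors 31-36 on a black background;
    # '\x1b[0m' resets the terminal attributes at the end.
    tokens = ['a</w>', 'fantasy</w>', 'land', 'scape</w>']  # hypothetical tokenizer output
    out = ''
    for i, raw in enumerate(tokens):
        token = raw.replace('</w>', ' ')  # show the end-of-word marker as a space
        s = (i % 6) + 1                   # cycle through six colors
        out += f'\x1b[0;3{s};40m{token}'
    print(out + '\x1b[0m')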
ldm/dream/generator/base.py: 1 change (1 addition & 0 deletions)
@@ -169,3 +169,4 @@ def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
         v2 = torch.from_numpy(v2).to(self.model.device)
 
         return v2
+
ldm/simplet2i.py: 31 changes (2 additions & 29 deletions)
@@ -292,7 +292,7 @@ def process_image(image,seed):
         init_image = None
 
         try:
-            uc, c = Conditioning(self.model,self._log_tokenization).get_uc_and_c(prompt, skip_normalize)
+            uc, c = Conditioning(self.model,self.log_tokenization).get_uc_and_c(prompt, skip_normalize)
 
             if init_img:
                 assert os.path.exists(init_img), f'>> {init_img}: File not found'
@@ -467,16 +467,14 @@ def _set_sampler(self):
     def _load_model_from_config(self, config, ckpt):
         print(f'>> Loading model from {ckpt}')
         pl_sd = torch.load(ckpt, map_location='cpu')
-        # if "global_step" in pl_sd:
-        #     print(f"Global Step: {pl_sd['global_step']}")
         sd = pl_sd['state_dict']
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
         model.to(self.device)
         model.eval()
         if self.full_precision:
             print(
-                'Using slower but more accurate full-precision math (--full_precision)'
+                '>> Using slower but more accurate full-precision math (--full_precision)'
             )
         else:
             print(
@@ -535,31 +533,6 @@ def _fit_image(self,image,max_dimensions):
         )
         return image
 
-    # shows how the prompt is tokenized
-    # usually tokens have '</w>' to indicate end-of-word,
-    # but for readability it has been replaced with ' '
-    def _log_tokenization(self, text):
-        if not self.log_tokenization:
-            return
-        tokens = self.model.cond_stage_model.tokenizer._tokenize(text)
-        tokenized = ""
-        discarded = ""
-        usedTokens = 0
-        totalTokens = len(tokens)
-        for i in range(0, totalTokens):
-            token = tokens[i].replace('</w>', ' ')
-            # alternate color
-            s = (usedTokens % 6) + 1
-            if i < self.model.cond_stage_model.max_length:
-                tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
-                usedTokens += 1
-            else: # over max token length
-                discarded = discarded + f"\x1b[0;3{s};40m{token}"
-        print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
-        if discarded != "":
-            print(
-                f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
-
     def _resolution_check(self, width, height, log=False):
         resize_needed = False
         w, h = map(
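The net effect of the refactor: call sites now pass a boolean flag instead of a logging callback, and Conditioning supplies its own no-op when logging is off. A minimal runnable sketch of that pattern (the demo class is illustrative, not part of the repository):

    class TokenLoggerDemo:
        def __init__(self, log_tokens=False):
            # the no-op lambda takes one argument so it matches _log(text)
            self.logger = self._log if log_tokens else (lambda text: None)

        def _log(self, text):
            print(f'>> would tokenize: {text}')

    TokenLoggerDemo(log_tokens=True).logger('a castle on a hill')   # prints
    TokenLoggerDemo(log_tokens=False).logger('a castle on a hill')  # silently ignored

As for the '# right way to make a noop?' question in the diff, 'lambda *args, **kwargs: None' is a common alternative that tolerates any call signature.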
