Commit d4630ab

hako-mikan committed Dec 17, 2023
1 parent d85898a commit d4630ab
Showing 1 changed file with 75 additions and 22 deletions.

97 changes: 75 additions & 22 deletions scripts/negpip.py
@@ -42,6 +42,8 @@ def __init__(self):
self.unlen = []
self.contokens = []
self.untokens = []
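# New state for hires. fix support: self.hr marks whether denoising has
# entered the hires pass, and self.x caches the first latent shape seen so a
# shape change can be detected in the denoiser callback.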
self.hr = False
self.x = None

self.ipa = None

@@ -88,6 +90,8 @@ def f_toggle(is_img2img):

def process_batch(self, p, active,**kwargs):
self.__init__()
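# 'flag' is set while parsing the prompts once any negatively weighted token
# is found; if none is found, NegPiP unloads itself for this batch.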
flag = False

if getattr(shared.opts,OPT_HIDE, False) and not getattr(shared.opts,OPT_ACT, False): return
elif not active: return

@@ -98,14 +102,14 @@ def process_batch(self, p, active,**kwargs):
if "rp.py" in script.filename:
self.rpscript = script

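# hr_dealer (defined at the end of this file) reports whether separate hires
# prompts / negative prompts were supplied on this processing object.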
self.hrp, self.hrn = hr_dealer(p)

self.active = active
self.batch = p.batch_size
self.isxl = hasattr(shared.sd_model,"conditioner")

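# These samplers appear to order each chunked batch as (uncond, cond) rather
# than (cond, uncond); self.rev swaps the chunks accordingly in the hook.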
self.rev = p.sampler_name in ["DDIM", "PLMS", "UniPC"]

tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.isxl else shared.sd_model.cond_stage_model.tokenize_line

@@ -114,6 +118,7 @@ def getshedulednegs(scheduled,prompts):

def getshedulednegs(scheduled,prompts):
output = []
nonlocal flag
for i, batch_shedule in enumerate(scheduled):
stepout = []
seps = None
@@ -141,18 +146,25 @@ def getshedulednegs(scheduled,prompts):
if text == "BREAK": continue
if weight < 0:
textweights.append([text,weight])
flag = True
padtextweight.append([padd,textweights])
tokens, tokensnum = tokenizer(sep_prompt)
padd = tokensnum // 75 + 1 + padd
stepout.append([step,padtextweight])
output.append(stepout)
return output

scheduled_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.prompts,p.steps)
scheduled_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.negative_prompts,p.steps)

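# The hires pass may run with its own prompts and step count; when
# hr_second_pass_steps is 0 ("same as first pass") the base step count is used.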
if self.hrp: scheduled_hr_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.hr_prompts,p.hr_second_pass_steps if p.hr_second_pass_steps > 0 else p.steps)
if self.hrn: scheduled_hr_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.hr_negative_prompts,p.hr_second_pass_steps if p.hr_second_pass_steps > 0 else p.steps)

nip = getshedulednegs(scheduled_p,p.prompts)
pin = getshedulednegs(scheduled_np,p.negative_prompts)

if self.hrp: hr_nip = getshedulednegs(scheduled_hr_p,p.hr_prompts)
if self.hrn: hr_pin = getshedulednegs(scheduled_hr_np,p.hr_negative_prompts)

def conddealer(targets):
conds =[]
@@ -197,6 +209,9 @@ def calcconds(targetlist):
self.conds_all = calcconds(nip)
self.unconds_all = calcconds(pin)

if self.hrp: self.hr_conds_all = calcconds(hr_nip)
if self.hrn: self.hr_unconds_all = calcconds(hr_pin)

#print(self.conds_all)
#print(self.unconds_all)

@@ -208,6 +223,11 @@ def calcsets(A, B):
self.conlen = calcsets(tokenizer(p.prompts[0])[1],75)
self.unlen = calcsets(tokenizer(p.negative_prompts[0])[1],75)

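# No negatively weighted token was found in any prompt, so there is nothing
# for NegPiP to do: deactivate and remove the hooks for this run.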
if not flag:
self.active = False
unload(self,p)
return

if not hasattr(self,"negpip_dr_callbacks"):
self.negpip_dr_callbacks = on_cfg_denoiser(self.denoiser_callback)

@@ -227,20 +247,21 @@ def calcsets(A, B):
})

def postprocess(self, p, processed, *args):
if hasattr(self,"handle"):
hook_forwards(self, p.sd_model.model.diffusion_model, remove=True)
del self.handle
unload(self,p)
self.conds_all = None
self.unconds_all = None

def denoiser_callback(self, params: CFGDenoiserParams):
if debug: print(params.text_cond.shape)
if self.active:
if self.x is None: self.x = params.x.shape
if self.x != params.x.shape: self.hr = True
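# A change in latent shape means the hires pass has begun; from then on the
# schedules built from the hires prompts are selected below.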

self.latenti = 0

condslist = []
tokenslist = []
for step, regions in self.hr_conds_all[0] if self.hrp and self.hr else self.conds_all[0]:
if step >= params.sampling_step + 2:
for region, conds, tokens in regions:
condslist.append(conds)
@@ -252,7 +273,7 @@ def denoiser_callback(self, params: CFGDenoiserParams):

uncondslist = []
untokenslist = []
for step, regions in self.hr_unconds_all[0] if self.hrn and self.hr else self.unconds_all[0]:
if step >= params.sampling_step + 2:
for region, unconds, untokens in regions:
uncondslist.append(unconds)
@@ -264,58 +285,82 @@ def denoiser_callback(self, params: CFGDenoiserParams):

from pprint import pprint

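# Detach NegPiP's attention hooks from the U-Net; safe to call even if no
# hook was installed (guarded by hasattr).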
def unload(self,p):
if hasattr(self,"handle"):
hook_forwards(self, p.sd_model.model.diffusion_model, remove=True)
del self.handle

def hook_forward(self, module):

def forward(x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
if debug: print("x.shape:",x.shape,"context.shape:",context.shape,"self.contokens",self.contokens,"self.untokens",self.untokens)
if debug: print(" x.shape:",x.shape,"context.shape:",context.shape,"self.contokens",self.contokens,"self.untokens",self.untokens)

def sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,conds,contokens,unconds,untokens, latent = None):
if debug: print("x.shape[0]:",x.shape[0],"batch:",self.batch *2)
if debug: print(" x.shape[0]:",x.shape[0],"batch:",self.batch *2)

if x.shape[0] == self.batch *2:
if debug: print("x.shape[0] == self.batch *2")
if debug: print(" x.shape[0] == self.batch *2")

if self.rev:
contn,contp = context.chunk(2)
ixn,ixp = x.chunk(2)
else:
contp,contn = context.chunk(2)
ixp,ixn = x.chunk(2) #x[0:self.batch,:,:],x[self.batch:,:,:]

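# The extra NegPiP conds are built per prompt, so if the incoming context has
# a larger batch dimension they are expanded to match before concatenation.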
if conds is not None:
if contp.shape[0] != conds.shape[0]:
conds = conds.expand(contp.shape[0],-1,-1)
contp = torch.cat((contp,conds),1)
if unconds is not None:
if contn.shape[0] != unconds.shape[0]:
unconds = unconds.expand(contn.shape[0],-1,-1)
contn = torch.cat((contn,unconds),1)
xp = main_foward(self, module, ixp,contp,mask,additional_tokens,n_times_crossframe_attn_in_self,contokens)
xn = main_foward(self, module, ixn,contn,mask,additional_tokens,n_times_crossframe_attn_in_self,untokens)
#print(conds,contokens,unconds,untokens)

out = torch.cat([xn,xp]) if self.rev else torch.cat([xp,xn])
return out

elif latent is not None:
if debug:print("latent is not None")
if debug:print(" latent is not None")
if latent:
conds = conds if conds is not None else None
else:
conds = unconds if unconds is not None else None

if conds is not None:
if context.shape[0] != conds.shape[0]:
conds = conds.expand(context.shape[0],-1,-1)
context = torch.cat([context,conds],1)

tokens = contokens if contokens is not None else untokens

out = main_foward(self, module, x,context,mask,additional_tokens,n_times_crossframe_attn_in_self,tokens)
return out

else:
#print("else")
if debug: print(context.shape[1] , self.conlen,self.unlen)
if debug:
print(" Else")
print(context.shape[1] , self.conlen,self.unlen)

tokens = []
concon = counter(self.isxl)
if debug: print(concon)
if context.shape[1] == self.conlen * 77 and concon:
if conds is not None:
if context.shape[0] != conds.shape[0]:
conds = conds.expand(context.shape[0],-1,-1)
context = torch.cat([context,conds],1)
tokens = contokens
elif context.shape[1] == self.unlen * 77 and concon:
if unconds is not None:
if context.shape[0] != unconds.shape[0]:
unconds = unconds.expand(context.shape[0],-1,-1)
context = torch.cat([context,unconds],1)
tokens = untokens
out = main_foward(self, module, x,context,mask,additional_tokens,n_times_crossframe_attn_in_self,tokens)
return out

if self.enable_rp_latent:
if len(self.conds) - 1 >= self.latenti:
out = sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,self.conds[self.latenti],self.contokens[self.latenti],None,None ,latent = True)
@@ -325,11 +370,11 @@ def sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_
self.latenti = 0
return out
else:
if self.conds is not None and self.unconds is not None and len(self.conds) > 0 and len(self.unconds) > 0:
return sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,self.conds[0],self.contokens[0],self.unconds[0],self.untokens[0])
else:
return sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,None,None,None,None)

return forward

count = 0
Expand Down Expand Up @@ -420,3 +465,11 @@ def ext_on_ui_settings():
shared.opts.add_option(cur_setting_name, shared.OptionInfo(*option_info, section=section))

on_ui_settings(ext_on_ui_settings)

def hr_dealer(p):
if not hasattr(p, "hr_prompts"):
p.hr_prompts = None
if not hasattr(p, "hr_negative_prompts"):
p.hr_negative_prompts = None

return bool(p.hr_prompts), bool(p.hr_negative_prompts)
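# A minimal usage sketch (assumption: 'p' stands in for webui's
# StableDiffusionProcessing object; attribute names follow that object):
#
#     class P: pass
#     p = P()
#     p.hr_prompts = ["a castle"]
#     hrp, hrn = hr_dealer(p)   # -> (True, False); p.hr_negative_prompts is now None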
