diff --git a/scripts/negpip.py b/scripts/negpip.py index 4c9ebd8..30e27ea 100644 --- a/scripts/negpip.py +++ b/scripts/negpip.py @@ -42,6 +42,8 @@ def __init__(self): self.unlen = [] self.contokens = [] self.untokens = [] + self.hr = False + self.x = None self.ipa = None @@ -88,6 +90,8 @@ def f_toggle(is_img2img): def process_batch(self, p, active,**kwargs): self.__init__() + flag = False + if getattr(shared.opts,OPT_HIDE, False) and not getattr(shared.opts,OPT_ACT, False): return elif not active: return @@ -98,14 +102,14 @@ def process_batch(self, p, active,**kwargs): if "rp.py" in script.filename: self.rpscript = script + self.hrp, self.hrn = hr_dealer(p) + self.active = active self.batch = p.batch_size self.isxl = hasattr(shared.sd_model,"conditioner") self.rev = p.sampler_name in ["DDIM", "PLMS", "UniPC"] - scheduled_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.prompts,p.steps) - scheduled_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.negative_prompts,p.steps) tokenizer = shared.sd_model.conditioner.embedders[0].tokenize_line if self.isxl else shared.sd_model.cond_stage_model.tokenize_line @@ -114,6 +118,7 @@ def process_batch(self, p, active,**kwargs): def getshedulednegs(scheduled,prompts): output = [] + nonlocal flag for i, batch_shedule in enumerate(scheduled): stepout = [] seps = None @@ -141,6 +146,7 @@ def getshedulednegs(scheduled,prompts): if text == "BREAK": continue if weight < 0: textweights.append([text,weight]) + flag = True padtextweight.append([padd,textweights]) tokens, tokensnum = tokenizer(sep_prompt) padd = tokensnum // 75 + 1 + padd @@ -148,11 +154,17 @@ def getshedulednegs(scheduled,prompts): output.append(stepout) return output + scheduled_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.prompts,p.steps) + scheduled_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.negative_prompts,p.steps) + + if self.hrp: scheduled_hr_p = prompt_parser.get_learned_conditioning_prompt_schedules(p.hr_prompts,p.hr_second_pass_steps if p.hr_second_pass_steps > 0 else p.step) + if self.hrn: scheduled_hr_np = prompt_parser.get_learned_conditioning_prompt_schedules(p.hr_negative_prompts,p.hr_second_pass_steps if p.hr_second_pass_steps > 0 else p.step) + nip = getshedulednegs(scheduled_p,p.prompts) pin = getshedulednegs(scheduled_np,p.negative_prompts) - p.hr_prompts = p.prompts - p.hr_negative_prompts = p.negative_prompts + if self.hrp: hr_nip = getshedulednegs(scheduled_hr_p,p.hr_prompts) + if self.hrn: hr_pin = getshedulednegs(scheduled_hr_np,p.hr_negative_prompts) def conddealer(targets): conds =[] @@ -197,6 +209,9 @@ def calcconds(targetlist): self.conds_all = calcconds(nip) self.unconds_all = calcconds(pin) + if self.hrp: self.hr_conds_all = calcconds(hr_nip) + if self.hrn: self.hr_unconds_all = calcconds(hr_pin) + #print(self.conds_all) #print(self.unconds_all) @@ -208,6 +223,11 @@ def calcsets(A, B): self.conlen = calcsets(tokenizer(p.prompts[0])[1],75) self.unlen = calcsets(tokenizer(p.negative_prompts[0])[1],75) + if not flag: + self.active = False + unload(self,p) + return + if not hasattr(self,"negpip_dr_callbacks"): self.negpip_dr_callbacks = on_cfg_denoiser(self.denoiser_callback) @@ -227,20 +247,21 @@ def calcsets(A, B): }) def postprocess(self, p, processed, *args): - if hasattr(self,"handle"): - hook_forwards(self, p.sd_model.model.diffusion_model, remove=True) - del self.handle + unload(self,p) self.conds_all = None self.unconds_all = None def denoiser_callback(self, params: CFGDenoiserParams): if debug: print(params.text_cond.shape) if self.active: + if self.x is None: self.x = params.x.shape + if self.x != params.x.shape: self.hr = True + self.latenti = 0 condslist = [] tokenslist = [] - for step, regions in self.conds_all[0]: + for step, regions in self.hr_conds_all[0] if self.hrp and self.hr else self.conds_all[0]: if step >= params.sampling_step + 2: for region, conds, tokens in regions: condslist.append(conds) @@ -252,7 +273,7 @@ def denoiser_callback(self, params: CFGDenoiserParams): uncondslist = [] untokenslist = [] - for step, regions in self.unconds_all[0]: + for step, regions in self.hr_unconds_all[0] if self.hrn and self.hr else self.unconds_all[0]: if step >= params.sampling_step + 2: for region, unconds, untokens in regions: uncondslist.append(unconds) @@ -264,58 +285,82 @@ def denoiser_callback(self, params: CFGDenoiserParams): from pprint import pprint +def unload(self,p): + if hasattr(self,"handle"): + hook_forwards(self, p.sd_model.model.diffusion_model, remove=True) + del self.handle + def hook_forward(self, module): + def forward(x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): - if debug: print("x.shape:",x.shape,"context.shape:",context.shape,"self.contokens",self.contokens,"self.untokens",self.untokens) + if debug: print(" x.shape:",x.shape,"context.shape:",context.shape,"self.contokens",self.contokens,"self.untokens",self.untokens) + def sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,conds,contokens,unconds,untokens, latent = None): - if debug: print("x.shape[0]:",x.shape[0],"batch:",self.batch *2) + if debug: print(" x.shape[0]:",x.shape[0],"batch:",self.batch *2) + if x.shape[0] == self.batch *2: - if debug: print("x.shape[0] == self.batch *2") + if debug: print(" x.shape[0] == self.batch *2") + if self.rev: contn,contp = context.chunk(2) ixn,ixp = x.chunk(2) else: contp,contn = context.chunk(2) - ixp,ixn = x.chunk(2)#x[0:self.batch,:,:],x[self.batch:,:,:] + ixp,ixn = x.chunk(2) #x[0:self.batch,:,:],x[self.batch:,:,:] - if conds is not None:contp = torch.cat((contp,conds),1) - if unconds is not None:contn = torch.cat((contn,unconds),1) + if conds is not None: + if contp.shape[0] != conds.shape[0]: + conds = conds.expand(contp.shape[0],-1,-1) + contp = torch.cat((contp,conds),1) + if unconds is not None: + if contn.shape[0] != unconds.shape[0]: + unconds = unconds.expand(contn.shape[0],-1,-1) + contn = torch.cat((contn,unconds),1) xp = main_foward(self, module, ixp,contp,mask,additional_tokens,n_times_crossframe_attn_in_self,contokens) xn = main_foward(self, module, ixn,contn,mask,additional_tokens,n_times_crossframe_attn_in_self,untokens) - #print(conds,contokens,unconds,untokens) out = torch.cat([xn,xp]) if self.rev else torch.cat([xp,xn]) return out elif latent is not None: - if debug:print("latent is not None") + if debug:print(" latent is not None") if latent: conds = conds if conds is not None else None else: conds = unconds if unconds is not None else None - - if conds is not None:context = torch.cat([context,conds],1) + if conds is not None: + if context.shape[0] != conds.shape[0]: + conds = conds.expand(context.shape[0],-1,-1) + context = torch.cat([context,conds],1) + tokens = contokens if contokens is not None else untokens out = main_foward(self, module, x,context,mask,additional_tokens,n_times_crossframe_attn_in_self,tokens) return out else: - #print("else") - if debug: print(context.shape[1] , self.conlen,self.unlen) + if debug: + print(" Else") + print(context.shape[1] , self.conlen,self.unlen) + tokens = [] concon = counter(self.isxl) if debug: print(concon) if context.shape[1] == self.conlen * 77 and concon: if conds is not None: + if context.shape[0] != conds.shape[0]: + conds = conds.expand(context.shape[0],-1,-1) context = torch.cat([context,conds],1) tokens = contokens elif context.shape[1] == self.unlen * 77 and concon: if unconds is not None: + if context.shape[0] != unconds.shape[0]: + unconds = unconds.expand(context.shape[0],-1,-1) context = torch.cat([context,unconds],1) tokens = untokens out = main_foward(self, module, x,context,mask,additional_tokens,n_times_crossframe_attn_in_self,tokens) return out + if self.enable_rp_latent: if len(self.conds) - 1 >= self.latenti: out = sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,self.conds[self.latenti],self.contokens[self.latenti],None,None ,latent = True) @@ -325,11 +370,11 @@ def sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_ self.latenti = 0 return out else: - #print(self.conds,self.unconds) if self.conds is not None and self.unconds is not None and len(self.conds) > 0 and len(self.unconds) > 0: return sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,self.conds[0],self.contokens[0],self.unconds[0],self.untokens[0]) else: return sub_forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self,None,None,None,None) + return forward count = 0 @@ -420,3 +465,11 @@ def ext_on_ui_settings(): shared.opts.add_option(cur_setting_name, shared.OptionInfo(*option_info, section=section)) on_ui_settings(ext_on_ui_settings) + +def hr_dealer(p): + if not hasattr(p, "hr_prompts"): + p.hr_prompts = None + if not hasattr(p, "hr_negative_prompts"): + p.hr_negative_prompts = None + + return bool(p.hr_prompts), bool(p.hr_negative_prompts )