Skip to content

Commit

Permalink
Refactor to make it easier to add custom conds to models.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Oct 25, 2023
1 parent 3fce888 commit 036f88c
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 173 deletions.
64 changes: 64 additions & 0 deletions comfy/conds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import enum
import torch
import math
import comfy.utils


def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)

class CONDRegular:
def __init__(self, cond):
self.cond = cond

def _copy_with(self, cond):
return self.__class__(cond)

def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))

def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
return True

def concat(self, others):
conds = [self.cond]
for x in others:
conds.append(x.cond)
return torch.cat(conds)

class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))


class CONDCrossAttn(CONDRegular):
def can_concat(self, other):
s1 = self.cond.shape
s2 = other.cond.shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False

mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
return True

def concat(self, others):
conds = [self.cond]
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
conds.append(c)

out = []
for c in conds:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c)
return torch.cat(out)
14 changes: 10 additions & 4 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
import numpy as np
from enum import Enum
from . import utils
Expand Down Expand Up @@ -49,7 +50,7 @@ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs):
if c_concat is not None:
xc = torch.cat([x] + [c_concat], dim=1)
else:
Expand All @@ -72,7 +73,8 @@ def is_adm(self):
def encode_adm(self, **kwargs):
return None

def cond_concat(self, **kwargs):
def extra_conds(self, **kwargs):
out = {}
if self.inpaint_model:
concat_keys = ("mask", "masked_image")
cond_concat = []
Expand Down Expand Up @@ -101,8 +103,12 @@ def blank_inpaint_image_like(latent_image):
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
return cond_concat
return None
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs)
if adm is not None:
out['c_adm'] = comfy.conds.CONDRegular(adm)
return out

def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
Expand Down
31 changes: 17 additions & 14 deletions comfy/sample.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import comfy.model_management
import comfy.samplers
import comfy.conds
import comfy.utils
import math
import numpy as np
Expand Down Expand Up @@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
noise_mask = noise_mask.to(device)
return noise_mask

def broadcast_cond(cond, batch, device):
"""broadcasts conditioning to the batch size"""
copy = []
for p in cond:
t = comfy.utils.repeat_to_batch_size(p[0], batch)
t = t.to(device)
copy += [[t] + p[1:]]
return copy

def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c[1]:
models += [c[1][model_type]]
if model_type in c:
models += [c[model_type]]
return models

def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0])
temp["model_conds"] = model_conds
out.append(temp)
return out

def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
Expand All @@ -72,6 +75,8 @@ def cleanup_additional_models(models):

def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
device = model.load_device
positive = convert_cond(positive)
negative = convert_cond(negative)

if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)
Expand All @@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
real_model = model.model

positive_copy = broadcast_cond(positive, noise_shape[0], device)
negative_copy = broadcast_cond(negative, noise_shape[0], device)
return real_model, positive_copy, negative_copy, noise_mask, models
return real_model, positive, negative, noise_mask, models


def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
Expand Down
Loading

0 comments on commit 036f88c

Please sign in to comment.