Commit
Refactor to make it easier to add custom conds to models.
1 parent 3fce888 · commit 036f88c
Showing 4 changed files with 170 additions and 173 deletions.
@@ -0,0 +1,64 @@
import enum
import torch
import math
import comfy.utils


def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
    return abs(a*b) // math.gcd(a, b)


class CONDRegular:
    """Base class for conditioning values: repeated to the batch size and moved to the device as-is."""
    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        return self.__class__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))

    def can_concat(self, other):
        if self.cond.shape != other.cond.shape:
            return False
        return True

    def concat(self, others):
        conds = [self.cond]
        for x in others:
            conds.append(x.cond)
        return torch.cat(conds)


class CONDNoiseShape(CONDRegular):
    """Conditioning shaped like the latent: cropped to the sampled area before batching."""
    def process_cond(self, batch_size, device, area, **kwargs):
        data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
        return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))


class CONDCrossAttn(CONDRegular):
    """Cross-attention conditioning: token dimensions may be padded (by repetition) to a common multiple."""
    def can_concat(self, other):
        s1 = self.cond.shape
        s2 = other.cond.shape
        if s1 != s2:
            if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
                return False

            mult_min = lcm(s1[1], s2[1])
            diff = mult_min // min(s1[1], s2[1])
            if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
                return False
        return True

    def concat(self, others):
        conds = [self.cond]
        crossattn_max_len = self.cond.shape[1]
        for x in others:
            c = x.cond
            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
            conds.append(c)

        out = []
        for c in conds:
            if c.shape[1] < crossattn_max_len:
                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
            out.append(c)
        return torch.cat(out)
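The point of the refactor is that CONDRegular now defines the whole interface a conditioning value needs: process_cond(), can_concat(), and concat(). A model can add a custom cond by subclassing it and overriding only what differs. Below is a minimal sketch, assuming the classes above are importable as comfy.conds (the file path is not shown in this diff); the CONDConstant name and its non-tensor use case are illustrative, not part of this commit.

# Hypothetical example (not in this commit): a cond wrapping a plain value
# that must be passed through unchanged. Assumes the classes above live in
# a module importable as comfy.conds.
import comfy.conds

class CONDConstant(comfy.conds.CONDRegular):
    def process_cond(self, batch_size, device, **kwargs):
        # the wrapped value is not a tensor, so skip batching and device moves
        return self._copy_with(self.cond)

    def can_concat(self, other):
        # only identical values may share a batch
        return self.cond == other.cond

    def concat(self, others):
        # every entry is identical, so the shared value itself is the result
        return self.cond

Because callers go through only these three methods, a new cond type like this slots into the existing batching logic without modifying it.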