
Commit

Add support for GLIGEN textbox model.
comfyanonymous committed Apr 19, 2023
1 parent 472b1cc commit 3696d16
Showing 9 changed files with 491 additions and 28 deletions.
comfy/gligen.py: 343 additions & 0 deletions
@@ -0,0 +1,343 @@
import math
import torch
from torch import nn, einsum
from ldm.modules.attention import CrossAttention
from inspect import isfunction


def exists(val):
return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)

def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)


class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)

self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)

def forward(self, x):
return self.net(x)


class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()

self.attn = CrossAttention(
query_dim=query_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)

self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)

self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of tanh(alpha),
        # for example when it is set to 0, the entire model is the same as the
        # original one
self.scale = 1

def forward(self, x, objs):

x = x + self.scale * \
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

return x


class GatedSelfAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()

        # we need a linear projection since we concatenate the visual features
        # and the object features
self.linear = nn.Linear(context_dim, query_dim)

self.attn = CrossAttention(
query_dim=query_dim,
context_dim=query_dim,
heads=n_heads,
dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)

self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)

self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of tanh(alpha),
        # for example when it is set to 0, the entire model is the same as the
        # original one
self.scale = 1

def forward(self, x, objs):

N_visual = x.shape[1]
objs = self.linear(objs)

x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

return x
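
GatedSelfAttentionDense is the GLIGEN "fuser" block: visual tokens and projected grounding tokens are concatenated, run through self-attention, and only the visual slice is kept, gated by tanh(alpha_attn) so a freshly initialized block is a no-op. A minimal, illustrative sketch; the dimensions are assumed SD1.x values and are not part of this commit:

# Illustrative only: query_dim=320 and context_dim=768 are assumed SD1.x values.
fuser = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)
x = torch.randn(2, 16 * 16, 320)   # visual tokens from a 16x16 feature map
objs = torch.randn(2, 30, 768)     # up to 30 grounding tokens from PositionNet
out = fuser(x, objs)               # same shape as x: [2, 256, 320]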


class GatedSelfAttentionDense2(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()

        # we need a linear projection since we concatenate the visual features
        # and the object features
self.linear = nn.Linear(context_dim, query_dim)

self.attn = CrossAttention(
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)

self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)

self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of tanh(alpha),
        # for example when it is set to 0, the entire model is the same as the
        # original one
self.scale = 1

def forward(self, x, objs):

B, N_visual, _ = x.shape
B, N_ground, _ = objs.shape

objs = self.linear(objs)

# sanity check
size_v = math.sqrt(N_visual)
size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual token count must be a perfect square"
        assert int(size_g) == size_g, "Grounding token count must be a perfect square"
size_v = int(size_v)
size_g = int(size_g)

# select grounding token and resize it to visual token size as residual
out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
:, N_visual:, :]
out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
out = torch.nn.functional.interpolate(
out, (size_v, size_v), mode='bicubic')
residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

# add residual to visual feature
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

return x
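
The resize in GatedSelfAttentionDense2 is easier to follow with concrete numbers; the shapes below are purely illustrative:

# x:    [B, 4096, C] -> N_visual = 4096, size_v = 64
# objs: [B,   16, C] -> N_ground = 16,   size_g = 4
# attention output sliced to the grounding tokens [:, N_visual:, :] -> [B, 16, C]
# permute + reshape   -> [B, C, 4, 4]
# bicubic interpolate -> [B, C, 64, 64]
# reshape + permute   -> [B, 4096, C], added to x as a gated residual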


class FourierEmbedder():
def __init__(self, num_freqs=64, temperature=100):

self.num_freqs = num_freqs
self.temperature = temperature
self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

@torch.no_grad()
def __call__(self, x, cat_dim=-1):
"x: arbitrary shape of tensor. dim: cat dim"
out = []
for freq in self.freq_bands:
out.append(torch.sin(freq * x))
out.append(torch.cos(freq * x))
return torch.cat(out, cat_dim)
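
For each frequency band the embedder appends a sine and a cosine of the scaled input, so a [B, N, 4] box tensor becomes [B, N, 4 * 2 * num_freqs] when concatenated along the last dim. A quick illustrative check (not part of the commit):

embedder = FourierEmbedder(num_freqs=8)
boxes = torch.rand(2, 30, 4)   # normalized xyxy boxes
emb = embedder(boxes)          # [2, 30, 64]: 4 coords * 2 (sin, cos) * 8 freqs = 64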


class PositionNet(nn.Module):
def __init__(self, in_dim, out_dim, fourier_freqs=8):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim

self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy

self.linears = nn.Sequential(
nn.Linear(self.in_dim + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)

self.null_positive_feature = torch.nn.Parameter(
torch.zeros([self.in_dim]))
self.null_position_feature = torch.nn.Parameter(
torch.zeros([self.position_dim]))

def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
masks = masks.unsqueeze(-1)

        # embed the box positions (the input may include padding as a placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C

# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)

# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * \
masks + (1 - masks) * positive_null
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

objs = self.linears(
torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
assert objs.shape == torch.Size([B, N, self.out_dim])
return objs
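
PositionNet turns each (text embedding, box) pair into a grounding token: the Fourier-encoded box is concatenated with the phrase embedding and pushed through a small MLP, while masked-out slots fall back to the learned null embeddings. A hedged usage sketch; in_dim=768 and out_dim=768 are assumptions matching SD1.x CLIP embeddings:

position_net = PositionNet(in_dim=768, out_dim=768)
boxes = torch.rand(1, 30, 4)                    # normalized xyxy boxes
masks = torch.zeros(1, 30)
masks[0, 0] = 1.0                               # only the first slot holds a real box
embeddings = torch.randn(1, 30, 768)            # per-box text embeddings, zeros for padding
objs = position_net(boxes, masks, embeddings)   # [1, 30, 768] grounding tokens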


class Gligen(nn.Module):
def __init__(self, modules, position_net, key_dim):
super().__init__()
self.module_list = nn.ModuleList(modules)
self.position_net = position_net
self.key_dim = key_dim
self.max_objs = 30

def _set_position(self, boxes, masks, positive_embeddings):
objs = self.position_net(boxes, masks, positive_embeddings)

def func(key, x):
module = self.module_list[key]
return module(x, objs)
return func

def set_position(self, latent_image_shape, position_params, device):
batch, c, h, w = latent_image_shape
masks = torch.zeros([self.max_objs], device="cpu")
boxes = []
positive_embeddings = []
for p in position_params:
x1 = (p[4]) / w
y1 = (p[3]) / h
x2 = (p[4] + p[2]) / w
y2 = (p[3] + p[1]) / h
masks[len(boxes)] = 1.0
boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
positive_embeddings += [p[0]]
append_boxes = []
append_conds = []
if len(boxes) < self.max_objs:
append_boxes = [torch.zeros(
[self.max_objs - len(boxes), 4], device="cpu")]
append_conds = [torch.zeros(
[self.max_objs - len(boxes), self.key_dim], device="cpu")]

box_out = torch.cat(
boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
masks = masks.unsqueeze(0).repeat(batch, 1)
conds = torch.cat(positive_embeddings +
append_conds).unsqueeze(0).repeat(batch, 1, 1)
return self._set_position(
box_out.to(device),
masks.to(device),
conds.to(device))

def set_empty(self, latent_image_shape, device):
batch, c, h, w = latent_image_shape
masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
box_out = torch.zeros([self.max_objs, 4],
device="cpu").repeat(batch, 1, 1)
conds = torch.zeros([self.max_objs, self.key_dim],
device="cpu").repeat(batch, 1, 1)
return self._set_position(
box_out.to(device),
masks.to(device),
conds.to(device))

def cleanup(self):
pass

def get_models(self):
return [self]

def load_gligen(sd):
sd_k = sd.keys()
output_list = []
key_dim = 768
for a in ["input_blocks", "middle_block", "output_blocks"]:
for b in range(20):
k_temp = filter(lambda k: "{}.{}.".format(a, b)
in k and ".fuser." in k, sd_k)
k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)

n_sd = {}
for k in k_temp:
n_sd[k[1]] = sd[k[0]]
if len(n_sd) > 0:
query_dim = n_sd["linear.weight"].shape[0]
key_dim = n_sd["linear.weight"].shape[1]

if key_dim == 768: # SD1.x
n_heads = 8
d_head = query_dim // n_heads
else:
d_head = 64
n_heads = query_dim // d_head

gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
gated.load_state_dict(n_sd, strict=False)
output_list.append(gated)

if "position_net.null_positive_feature" in sd_k:
in_dim = sd["position_net.null_positive_feature"].shape[0]
out_dim = sd["position_net.linears.4.weight"].shape[0]

class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
w.load_state_dict(sd, strict=False)

gligen = Gligen(output_list, w.position_net, key_dim)
return gligen
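
A rough end-to-end sketch of how these pieces fit together. The checkpoint path is hypothetical, and the position_params layout of (positive_embedding, height, width, y, x) in latent-space coordinates is inferred from set_position above:

# Hypothetical usage; assumes the file stores a flat GLIGEN state dict.
sd = torch.load("gligen_textbox.pth", map_location="cpu")
gligen = load_gligen(sd)

emb = torch.randn(1, 768)                   # pooled text embedding for one phrase (assumed)
position_params = [(emb, 32, 32, 16, 16)]   # (embedding, height, width, y, x) in latent pixels
patch_fn = gligen.set_position((1, 4, 64, 64), position_params, device="cpu")
# patch_fn(block_index, x) applies the matching fuser to hidden states x, e.g. from the
# "middle_patch" hook added to the transformer block below.
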
comfy/ldm/modules/attention.py: 16 additions & 0 deletions
@@ -510,6 +510,14 @@ def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

def _forward(self, x, context=None, transformer_options={}):
current_index = None
if "current_index" in transformer_options:
current_index = transformer_options["current_index"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {}

n = self.norm1(x)
if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@@ -518,11 +526,19 @@ def _forward(self, x, context=None, transformer_options={}):
n = self.attn1(n, context=context if self.disable_self_attn else None)

x += n
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
x = p(current_index, x)

n = self.norm2(x)
n = self.attn2(n, context=context)

x += n
x = self.ff(self.norm3(x)) + x

if current_index is not None:
transformer_options["current_index"] += 1
return x
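
The hunk above defines the patch contract: every callable registered in the transformer_options["patches"]["middle_patch"] list is invoked between self-attention and cross-attention as p(current_index, x) and must return hidden states of the same shape. The callable returned by Gligen.set_position matches this signature. A minimal sketch of the wiring; the dict layout follows the lookups in the hunk, while how it gets populated elsewhere is outside this diff:

# Hypothetical wiring of a middle_patch entry.
def identity_patch(current_index, x):
    # a no-op patch that obeys the middle_patch contract
    return x

transformer_options = {
    "patches": {"middle_patch": [identity_patch]},   # e.g. [patch_fn] from Gligen.set_position
}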


comfy/ldm/modules/diffusionmodules/openaimodel.py: 2 additions & 0 deletions
@@ -782,6 +782,8 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0

assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
