add generate to coca model #314

Merged
Merged 10 commits on Dec 22, 2022
Changes from all commits
src/open_clip/coca_model.py (83 changes: 73 additions & 10 deletions)
@@ -14,7 +14,7 @@
AttentionalPooler,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower

from .generation_utils import top_a, top_k, top_p

@dataclass
class MultimodalCfg(CLIPTextCfg):
@@ -91,6 +91,7 @@ def __init__(
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.context_length = self.positional_embedding.shape[0] - 1

self.cls_token = nn.Parameter(torch.randn(embed_dim))
self.visual = _build_vision_tower(
@@ -159,21 +160,25 @@ def _build_cls_mask(self, text, cast_dtype):
def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
cast_dtype = self.transformer.get_cast_dtype()

attn_mask = self.attn_mask[None, :].expand(
text.shape[0] * self.heads, *self.attn_mask.shape
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat(
[
x + self.positional_embedding[:seq_len, :].to(cast_dtype),
self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0])
],
dim=1
)
seq_len += 1 # seq is 1 longer as we added CLS
attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand(
text.shape[0] * self.heads, seq_len, seq_len
)
cls_mask = self._build_cls_mask(text, cast_dtype)

x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1)
x = x + self.positional_embedding.to(cast_dtype)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask + cls_mask)
x = x.permute(1, 0, 2) # LND -> NLD

# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), :] @ self.text_projection

cls_emb = x[torch.arange(x.shape[0]), -1]
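The reworked `encode_text` adds positional embeddings per position and reserves the last row of `positional_embedding` for the CLS token that is appended as one extra slot, which is what the new `context_length = self.positional_embedding.shape[0] - 1` in `__init__` accounts for. A minimal shape-bookkeeping sketch (dimensions are illustrative, not taken from the PR):

```python
import torch

# Illustrative sizes only: the positional embedding has context_length + 1 rows;
# the first seq_len rows go to the text tokens and the last row to the CLS slot.
context_length, width, batch = 76, 512, 2
positional_embedding = torch.randn(context_length + 1, width)

seq_len = 20                                  # a prompt shorter than context_length
tokens = torch.randn(batch, seq_len, width)   # stand-in for token_embedding(text)
cls_token = torch.randn(width)

x = torch.cat(
    [
        tokens + positional_embedding[:seq_len, :],                         # per-position embeddings
        (cls_token + positional_embedding[-1, :]).expand(batch, 1, width),  # CLS slot
    ],
    dim=1,
)
assert x.shape == (batch, seq_len + 1, width)  # one extra position for the CLS token
```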
@@ -195,3 +200,61 @@ def forward(self, image, text):
logits = self.to_logits(text_tokens)

return image_latents, text_latents, logits, labels, self.logit_scale.exp()

def generate(
self,
image,
text,
seq_len,
max_seq_len=None,
mask_prob = 0.0,
temperature = 1.,
filter_logits_fn = top_k,
filter_thres = 0.9,
min_p_pow = 2.0,
min_p_ratio = 0.02,
):

assert mask_prob < 1, "mask_prob must be smaller than 1."

if max_seq_len is None:
max_seq_len = self.context_length

was_training = self.training
num_dims = len(text.shape)

if num_dims == 1:
text = text[None, :]

_, t = text.shape
self.eval()
out = text

for _ in range(seq_len):
x = out[:, -max_seq_len:]

# TODO: adjust for dict output
logits = self(image, x)[2][:, -1]

if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)

elif filter_logits_fn is top_a:
filtered_logits = filter_logits_fn(
logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio
)
probs = F.softmax(filtered_logits / temperature, dim=-1)

sample = torch.multinomial(probs, 1)

out = torch.cat((out, sample), dim=-1)


out = out[:, t:]

if num_dims == 1:
out = out.squeeze(0)

self.train(was_training)
return out
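A minimal usage sketch for the new `generate` method. The config name, random image tensor, and prompt ids below are illustrative assumptions rather than part of the PR; with untrained weights the sampled tokens are meaningless, but the call signature and output shape match the added code:

```python
import torch
import open_clip

# Illustrative CoCa config; weights are random here, so this only exercises the API.
model, _, _ = open_clip.create_model_and_transforms("coca_ViT-B-32")
model.eval()

image = torch.randn(1, 3, 224, 224)      # stand-in for a preprocessed image batch
text = torch.randint(0, 49408, (1, 5))   # dummy prompt token ids (CLIP BPE vocab size)

with torch.no_grad():
    # Autoregressively sample 20 new tokens with the default top_k filtering.
    generated = model.generate(image, text, seq_len=20)

print(generated.shape)  # torch.Size([1, 20]); the prompt tokens are stripped
```

Note that `generate` switches the model to eval mode and restores the previous training state before returning, so it can be called mid-training run.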
src/open_clip/generation_utils.py (38 changes: 38 additions & 0 deletions)
@@ -0,0 +1,38 @@
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F

def exists(val):
return val is not None

# nucleus

def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

sorted_indices_to_remove = cum_probs > (1 - thres)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0

sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, thres = 0.9):
k = ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs

# top_a

def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits
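A small sketch of how the three filters behave, using the import path this diff creates (`open_clip.generation_utils`); the toy logits are illustrative:

```python
import torch
import torch.nn.functional as F
from open_clip.generation_utils import top_k, top_p, top_a

# A fake distribution over a 10-token "vocabulary".
logits = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.0, -1.0, -2.0, -3.0, -4.0, -5.0]])

# top_k keeps ceil((1 - thres) * vocab) candidates: with the default thres=0.9
# and 10 tokens, only the single largest logit survives; the rest become -inf.
print(top_k(logits))

# top_p keeps the smallest prefix of the sorted distribution whose cumulative
# probability reaches 1 - thres, so a lower thres keeps more tokens.
print(top_p(logits, thres=0.5))

# top_a drops tokens whose probability is below min_p_ratio * max_prob**min_p_pow
# and sets the survivors' logits to 1 (a uniform choice among them). It modifies
# its argument in place, hence the clone().
print(top_a(logits.clone()))

# Sampling from the filtered logits, as CoCa.generate does:
probs = F.softmax(top_k(logits) / 1.0, dim=-1)   # temperature = 1.0
print(torch.multinomial(probs, 1))
```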
src/open_clip/transformer.py (13 changes: 7 additions & 6 deletions)
@@ -627,17 +627,18 @@ def build_attention_mask(self):
return mask

def forward(self, image_embs, text_embs):
text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
seq_len = text_embs.shape[0]

for r, ca in zip(self.resblocks, self.cross_attn):
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
text_embs = checkpoint(r, text_embs, None, None, self.attn_mask)
text_embs = checkpoint(ca, text_embs, image_embs, image_embs, None)
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
else:
text_embs = r(text_embs, attn_mask=self.attn_mask)
text_embs = ca(text_embs, k_x=image_embs, v_x=image_embs)
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)

x = text_embs.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
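For context, the `[:seq_len, :seq_len]` slices exist because the decoder builds one full causal mask up front (via the `build_attention_mask` referenced in the hunk header) while generation feeds it sequences shorter than the configured context; slicing the top-left block keeps the mask aligned with the current input. A minimal sketch with illustrative sizes:

```python
import torch

def build_causal_mask(num_pos: int) -> torch.Tensor:
    # Standard causal construction: -inf strictly above the diagonal (future
    # positions blocked), zeros on and below it (attention allowed).
    mask = torch.empty(num_pos, num_pos)
    mask.fill_(float("-inf"))
    mask.triu_(1)
    return mask

full_mask = build_causal_mask(76)            # built once for the full context
seq_len = 12                                 # length of the current partial sequence
attn_mask = full_mask[:seq_len, :seq_len]    # still a valid causal mask for 12 tokens
assert attn_mask.shape == (12, 12)
```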