add generate to coca model (#314)
* add initial generative support

* make generation context_length independent

* remove kwargs

* last positional embeddings for CLS

* typo

* fix mask len

* add comment

* remove unused args

* simpler logic for input shorter than context length

Co-authored-by: gpucce <g.puccetti@gmail.com>
gpucce and gpucce committed Dec 22, 2022
1 parent 279e088 commit dee1ea5
Showing 3 changed files with 118 additions and 16 deletions.
83 changes: 73 additions & 10 deletions src/open_clip/coca_model.py
@@ -14,7 +14,7 @@
AttentionalPooler,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower

from .generation_utils import top_a, top_k, top_p

@dataclass
class MultimodalCfg(CLIPTextCfg):
@@ -91,6 +91,7 @@ def __init__(
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.context_length = self.positional_embedding.shape[0] - 1

self.cls_token = nn.Parameter(torch.randn(embed_dim))
self.visual = _build_vision_tower(
@@ -159,21 +160,25 @@ def _build_cls_mask(self, text, cast_dtype):
def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
cast_dtype = self.transformer.get_cast_dtype()

attn_mask = self.attn_mask[None, :].expand(
text.shape[0] * self.heads, *self.attn_mask.shape
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat(
[
x + self.positional_embedding[:seq_len, :].to(cast_dtype),
self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0])
],
dim=1
)
seq_len += 1 # seq is 1 longer as we added CLS
attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand(
text.shape[0] * self.heads, seq_len, seq_len
)
cls_mask = self._build_cls_mask(text, cast_dtype)

x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1)
x = x + self.positional_embedding.to(cast_dtype)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask + cls_mask)
x = x.permute(1, 0, 2) # LND -> NLD

# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), :] @ self.text_projection

cls_emb = x[torch.arange(x.shape[0]), -1]
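
For readers following the diff, here is a minimal standalone sketch (not part of the commit) of what the reworked embedding step does: with a positional table of 77 rows, the last row is reserved for the CLS token (hence context_length = 76), token embeddings are offset by only the first seq_len positional rows, and the CLS token plus that last positional row is appended at the end. All shapes below are illustrative assumptions.

```python
import torch

# Illustrative shapes: batch of 2 captions, 20 tokens each, width 512,
# positional table of 77 rows with the last row reserved for CLS.
batch, seq_len, width = 2, 20, 512
pos_emb = torch.randn(77, width)
cls_token = torch.randn(width)
tok_emb = torch.randn(batch, seq_len, width)   # stand-in for token_embedding(text)

x = torch.cat(
    [
        tok_emb + pos_emb[:seq_len, :],                        # positions 0 .. seq_len-1
        (cls_token + pos_emb[-1, :]).expand(batch, 1, width),  # CLS gets the last position
    ],
    dim=1,
)
assert x.shape == (batch, seq_len + 1, width)  # sequence is one token longer after CLS
```
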
@@ -195,3 +200,61 @@ def forward(self, image, text):
logits = self.to_logits(text_tokens)

return image_latents, text_latents, logits, labels, self.logit_scale.exp()

def generate(
self,
image,
text,
seq_len,
max_seq_len=None,
mask_prob = 0.0,
temperature = 1.,
filter_logits_fn = top_k,
filter_thres = 0.9,
min_p_pow = 2.0,
min_p_ratio = 0.02,
):

assert mask_prob < 1, "mask_prob must be smaller than 1."

if max_seq_len is None:
max_seq_len = self.context_length

was_training = self.training
num_dims = len(text.shape)

if num_dims == 1:
text = text[None, :]

_, t = text.shape
self.eval()
out = text

for _ in range(seq_len):
x = out[:, -max_seq_len:]

# TODO: adjust for dict output
logits = self(image, x)[2][:, -1]

if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)

elif filter_logits_fn is top_a:
filtered_logits = filter_logits_fn(
logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio
)
probs = F.softmax(filtered_logits / temperature, dim=-1)

sample = torch.multinomial(probs, 1)

out = torch.cat((out, sample), dim=-1)


out = out[:, t:]

if num_dims == 1:
out = out.squeeze(0)

self.train(was_training)
return out
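
A hedged usage sketch of the new method (not part of the commit). The config name, tokenizer call, and prompt length are assumptions for illustration; what matters is the call shape: generate samples seq_len new tokens autoregressively, feeding at most max_seq_len (defaulting to self.context_length) trailing tokens back through forward each step, and returns only the newly generated tokens.

```python
import torch
import open_clip
from open_clip.generation_utils import top_p

# Assumed setup: a CoCa config name and the stock tokenizer; weights are random here.
model, _, preprocess = open_clip.create_model_and_transforms("coca_ViT-B-32")
model.eval()

image = torch.randn(1, 3, 224, 224)                 # stand-in for a preprocessed image
prompt = open_clip.tokenize(["a photo of"])[:, :5]  # short prefix instead of a padded 77-token sequence

with torch.no_grad():
    generated = model.generate(
        image,
        prompt,
        seq_len=20,               # number of new tokens to sample
        filter_logits_fn=top_p,   # swap in nucleus sampling; default is top_k
        filter_thres=0.9,
        temperature=1.0,
    )

print(generated.shape)  # torch.Size([1, 20]) -- only the tokens generated after the prompt
```
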
38 changes: 38 additions & 0 deletions src/open_clip/generation_utils.py
@@ -0,0 +1,38 @@
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F

def exists(val):
return val is not None

# nucleus

def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

sorted_indices_to_remove = cum_probs > (1 - thres)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0

sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, thres = 0.9):
k = ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs

# top_a

def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits
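
A quick way to exercise the filters (outside the commit) on a dummy vocabulary. Note the threshold convention: top_k keeps ceil((1 - thres) * vocab_size) logits and top_p keeps the highest-probability tokens until their cumulative probability passes 1 - thres, so a larger thres filters more aggressively; top_a keeps every token whose probability clears min_p_ratio times the powered maximum probability and flattens the survivors to equal logits.

```python
import torch
import torch.nn.functional as F
from open_clip.generation_utils import top_k, top_p

torch.manual_seed(0)
logits = torch.randn(1, 10)                   # pretend vocabulary of 10 tokens

kept = top_k(logits, thres=0.9)               # ceil(0.1 * 10) = 1 logit survives
print((kept > float("-inf")).sum().item())    # -> 1

kept = top_p(logits, thres=0.9)               # smallest nucleus covering > 0.1 probability mass
probs = F.softmax(kept / 1.0, dim=-1)         # same softmax-with-temperature as generate()
next_token = torch.multinomial(probs, 1)      # and the same sampling step
print(next_token.shape)                       # -> torch.Size([1, 1])
```
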
13 changes: 7 additions & 6 deletions src/open_clip/transformer.py
@@ -627,17 +627,18 @@ def build_attention_mask(self):
return mask

def forward(self, image_embs, text_embs):
text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
seq_len = text_embs.shape[0]

for r, ca in zip(self.resblocks, self.cross_attn):
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
text_embs = checkpoint(r, text_embs, None, None, self.attn_mask)
text_embs = checkpoint(ca, text_embs, image_embs, image_embs, None)
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
else:
text_embs = r(text_embs, attn_mask=self.attn_mask)
text_embs = ca(text_embs, k_x=image_embs, v_x=image_embs)
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)

x = text_embs.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
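
The same pattern in isolation (a sketch, not the commit's code): the causal mask is built once for the decoder's full context length, and each forward pass crops it to the length of the batch it actually receives, which is what lets generate() feed prefixes shorter than the context length.

```python
import torch

context_length = 76   # illustrative decoder context length
seq_len = 21          # actual (CLS-extended) length of the current text batch

# Built once, as in build_attention_mask(): -inf above the diagonal, 0 elsewhere.
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)

# Cropped per forward pass, as in the loop above.
attn_mask = mask[:seq_len, :seq_len]
assert attn_mask.shape == (seq_len, seq_len)
assert (attn_mask.diagonal() == 0).all()   # each position attends to itself and earlier tokens
```
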
