Skip to content

Commit

Permalink
Tweaks to coca fix, rename is_training to output_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed May 9, 2024
1 parent 7079e7e commit 65f460e
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions src/open_clip/coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def forward(
text: Optional[torch.Tensor] = None,
image_latent: Optional[torch.Tensor] = None,
image_embs: Optional[torch.Tensor] = None,
is_training=True
output_labels: bool = True,
):
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)
Expand All @@ -170,19 +170,21 @@ def forward(

text_latent, token_embs = self._encode_text(text)

# TODO: add assertion to avoid bugs?
labels = text[:, 1:]
if is_training:
# FIXME this isn't an ideal solution, would like to improve -RW
labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
if output_labels:
# align text_embs and thus logits with labels for teacher-forcing caption loss
token_embs = token_embs[:, :-1]

logits = self.text_decoder(image_embs, token_embs)
out_dict = {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"labels": labels,
"logit_scale": self.logit_scale.exp()
}
if labels is not None:
out_dict["labels"] = labels
if self.logit_bias is not None:
out_dict["logit_bias"] = self.logit_bias
return out_dict
Expand Down Expand Up @@ -245,8 +247,11 @@ def generate(
logit_processor=logit_processor,
)
if fixed_output_length and output.shape[1] < seq_len:
return torch.cat(
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
pad_len = seq_len - output.shape[1]
return torch.cat((
output,
torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * self.pad_id
),
dim=1
)
return output
Expand All @@ -272,14 +277,19 @@ def generate(
if num_dims == 1:
text = text[None, :]

cur_len = text.shape[1]
self.eval()
out = text

while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, is_training=False)["logits"][:, -1]
logits = self(
image,
x,
image_latent=image_latent,
image_embs=image_embs,
output_labels=False,
)["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

Expand Down Expand Up @@ -376,7 +386,7 @@ def _generate_beamsearch(
model_inputs['text'],
image_latent=image_latent,
image_embs=image_embs,
is_training=False
output_labels=False,
)

for beam_group_idx in range(num_beam_groups):
Expand Down

0 comments on commit 65f460e

Please sign in to comment.