[WIP] Add mammut #641

Open
wants to merge 53 commits into base: main

53 commits
8d2589f
initial commits
gpucce May 4, 2023
4e21849
typos
gpucce May 9, 2023
2c967a9
first run
May 20, 2023
53cf4ed
t arg change
May 23, 2023
46a89a4
Merge remote-tracking branch 'upstream/main' into add_mammut
May 23, 2023
5827a10
better generation logic
gpucce May 23, 2023
fcee7c2
Merge remote-tracking branch 'upstream/main' into add_mammut
gpucce May 30, 2023
997b786
adjust notation
gpucce May 30, 2023
852b2f4
make as with cls at the end
gpucce Jun 24, 2023
88f70ba
transformer.py: MultimodalTransformer not using init_parameters
Jun 26, 2023
ad716d7
Merge branch 'fix_multimodal_transformer' of https://github.com/iejMa…
gpucce Jun 27, 2023
dcb18d7
test CI
gpucce Jun 27, 2023
491bb5c
move back ci
gpucce Jun 27, 2023
dce72e8
split pooling
gpucce Jun 28, 2023
623226c
Merge remote-tracking branch 'origin/double_att_pool' into fix_multim…
gpucce Jun 28, 2023
ee982c9
Merge branch 'fix_multimodal_transformer' of https://github.com/iejMa…
gpucce Jun 29, 2023
e191018
Merge branch 'fix_multimodal_transformer' into add_mammut
gpucce Jun 29, 2023
20dc72b
Merge branch 'main' into add_mammut
gpucce Jun 29, 2023
4b95dd4
fix padding and 0 loss
gpucce Jul 4, 2023
1c38622
Merge branch 'main' into add_mammut
gpucce Jul 4, 2023
826e799
revert multiple attn pooler
gpucce Jul 4, 2023
b25d6f9
rm white space
gpucce Jul 4, 2023
5a55161
duplicated init_parameters
gpucce Jul 4, 2023
26f39c0
missing is_decoder
gpucce Jul 5, 2023
fe6bc7b
small improvements
gpucce Jul 10, 2023
0ea74ac
move mage projection
gpucce Jul 19, 2023
e4258f1
remove useless kwargs
gpucce Jul 19, 2023
ddb647c
remove useless kwargs
gpucce Jul 19, 2023
a3d32fd
refactor mammut into MultimodalTransformer
gpucce Jul 19, 2023
5a29467
fix args
gpucce Jul 19, 2023
805ad2d
small improvements
gpucce Jul 19, 2023
46da15f
inherit init
gpucce Jul 19, 2023
53c6d39
make equal
gpucce Jul 19, 2023
1d46cfa
fix typo
gpucce Jul 19, 2023
2abdf12
output latents and logits
gpucce Jul 19, 2023
4c8dcee
add mammut L/14 config
gpucce Jul 31, 2023
2ca74c1
allow text=None in forward
gpucce Jul 31, 2023
52e7b3e
better decoder
gpucce Aug 8, 2023
4f0b832
better decoder
gpucce Aug 14, 2023
b980b57
Merge branch 'main' into add_mammut
gpucce Sep 9, 2023
dc8d841
Merge branch 'main' into add_mammut
gpucce Sep 19, 2023
f13b230
integrate transformers generate changes
gpucce Oct 21, 2023
29a052f
Merge branch 'main' into merge_main_mammut
gpucce Oct 26, 2023
e83ac00
stash pop
gpucce Oct 26, 2023
717fe7a
update generation
gpucce Oct 26, 2023
f64c995
uniform generation
gpucce Oct 30, 2023
f3912dc
make equal
gpucce Oct 30, 2023
c4bb198
add context legnth to mammut
gpucce Oct 30, 2023
ef9071b
add older poolign
gpucce Oct 30, 2023
cb619dc
Merge branch 'merge_main_mammut' into add_mammut
gpucce Oct 30, 2023
1e14e4c
Merge branch 'main' into add_mammut
gpucce Oct 30, 2023
b34393e
typo and print
gpucce Nov 1, 2023
cc50367
uniform everything
gpucce Nov 1, 2023
342 changes: 15 additions & 327 deletions src/open_clip/coca_model.py
@@ -13,32 +13,7 @@
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower

try:
from transformers import (
BeamSearchScorer,
LogitsProcessorList,
TopPLogitsWarper,
TopKLogitsWarper,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StoppingCriteriaList
)

GENERATION_TYPES = {
"top_k": TopKLogitsWarper,
"top_p": TopPLogitsWarper,
"beam_search": "beam_search"
}
_has_transformers = True
except ImportError as e:
GENERATION_TYPES = {
"top_k": None,
"top_p": None,
"beam_search": "beam_search"
}
_has_transformers = False
from .generation_utils import Generator


@dataclass
@@ -48,13 +23,18 @@ class MultimodalCfg(CLIPTextCfg):
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8
cross_attn_ratio: int = 1
does_full_decoding: bool = False
output_tokens: bool = False
has_mlp: bool = True


def _build_text_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
is_decoder=True,
):
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
act_layer = QuickGELU if quick_gelu else nn.GELU
@@ -68,15 +48,18 @@ def _build_text_decoder_tower(
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
cross_attn_ratio=multimodal_cfg.cross_attn_ratio,
has_mlp=multimodal_cfg.has_mlp,
output_dim=embed_dim,
output_tokens=multimodal_cfg.output_tokens,
act_layer=act_layer,
norm_layer=norm_layer,
)

return decoder


class CoCa(nn.Module):
class CoCa(nn.Module, Generator):
def __init__(
self,
embed_dim,
@@ -148,13 +131,8 @@ def encode_text(self, text, normalize: bool = True):
text_latent, _ = self._encode_text(text, normalize=normalize)
return text_latent

def forward(
self,
image,
text: Optional[torch.Tensor] = None,
image_latent: Optional[torch.Tensor] = None,
image_embs: Optional[torch.Tensor] = None,
):
def forward(self, image=None, text=None, image_latent=None, image_embs=None, is_training=True):
text_latent, token_embs = self._encode_text(text)
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)

@@ -164,7 +142,9 @@ def forward(
text_latent, token_embs = self._encode_text(text)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]
labels = text[:, 1:]
if is_training:
token_embs = token_embs[:, :-1]

logits = self.text_decoder(image_embs, token_embs)
return {
@@ -175,295 +155,3 @@
"logit_scale": self.logit_scale.exp()
}

def generate(
self,
image,
text=None,
seq_len=30,
max_seq_len=77,
temperature=1.,
generation_type="beam_search",
top_p=0.1, # keep tokens in the 1 - top_p quantile
top_k=1, # keeps the top_k most probable tokens
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
repetition_penalty=1.0,
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
):
# taking many ideas and components from HuggingFace GenerationMixin
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"

with torch.no_grad():
sot_token_id = 49406 if sot_token_id is None else sot_token_id
eos_token_id = 49407 if eos_token_id is None else eos_token_id
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
logit_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
]
)

if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]

stopping_criteria = StoppingCriteriaList(
stopping_criteria
)

device = image.device

if generation_type == "beam_search":
output = self._generate_beamsearch(
image_inputs=image,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
logit_processor=logit_processor,
)
if fixed_output_length and output.shape[1] < seq_len:
return torch.cat(
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
dim=1
)
return output

elif generation_type == "top_p":
logit_warper = GENERATION_TYPES[generation_type](top_p)
elif generation_type == "top_k":
logit_warper = GENERATION_TYPES[generation_type](top_k)
else:
raise ValueError(
f"generation_type has to be one of "
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
)

image_latent, image_embs = self._encode_image(image)

if text is None:
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

was_training = self.training
num_dims = len(text.shape)

if num_dims == 1:
text = text[None, :]

cur_len = text.shape[1]
self.eval()
out = text

while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(image, x, image_latent=image_latent, image_embs=image_embs)["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

if mask.all():
if not fixed_output_length:
break
else:
logits = logits[~mask, :]
filtered_logits = logit_processor(x[~mask, :], logits)
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
probs = F.softmax(filtered_logits / temperature, dim=-1)

if (cur_len + 1 == seq_len):
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
else:
sample[~mask, :] = torch.multinomial(probs, 1)

out = torch.cat((out, sample), dim=-1)

cur_len += 1

if stopping_criteria(out, None):
break

if num_dims == 1:
out = out.squeeze(0)

self.train(was_training)
return out

def _generate_beamsearch(
self,
image_inputs,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
logit_processor=None,
logit_warper=None,
):
device = image_inputs.device
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self._encode_image(image_inputs)

input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
num_beam_groups=num_beam_groups,
)
# instantiate logits processors
logits_processor = (
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
if logit_processor is None
else logit_processor
)

num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None

if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)

beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
# the same group don't produce the same tokens every time.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))

while True:

# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
outputs = self(
model_inputs['images'],
model_inputs['text'],
image_latent=image_latent,
image_embs=image_embs
)

for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx

# indices of beams of current group among all sentences in batch
batch_group_indices = []

for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]

# select outputs of beams of current group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
vocab_size = next_token_logits.shape[-1]

next_token_scores_processed = logits_processor(
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)

next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size

# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
group_index=beam_group_idx,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]

input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]

# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)

input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, None):
break

final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences']


def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
return {
"text": input_ids,
"images": image_inputs,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
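
The behavioural core of the forward() change in this diff is the teacher-forcing alignment: labels drop the first (start-of-text) token, and during training the decoder input embeddings drop the last position, so the logit at position i is scored against token i+1. Below is a minimal, self-contained sketch of that alignment using hypothetical tensor shapes; it is illustrative only and not code from this PR.

    import torch

    batch_size, seq_len, width = 2, 8, 16                   # hypothetical sizes
    text = torch.randint(0, 49408, (batch_size, seq_len))   # tokenized caption ids
    token_embs = torch.randn(batch_size, seq_len, width)    # decoder input embeddings

    labels = text[:, 1:]             # targets: each position predicts the next token
    token_embs = token_embs[:, :-1]  # training: drop the last position
    assert labels.shape[1] == token_embs.shape[1]  # logits and labels now line up

With this shift, the caption loss at position i is computed against token i+1, replacing the previous labels = text[:, -token_embs.shape[1]:] slicing and keeping logits and labels the same length when is_training is True.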