Add coca trained (#307) #308

Merged on Jan 29, 2023 (34 commits)
Commits
1b86601
Add coca trained (#307)
rom1504 Dec 20, 2022
29fa332
Add coca to CI
rom1504 Dec 21, 2022
911c737
Add coca to CI pr
rom1504 Dec 21, 2022
b4881bc
simplify encode_iamge (#313)
gpucce Dec 21, 2022
50bc599
Add cls mask (#312)
gpucce Dec 21, 2022
279e088
Ignore pad tokens in captioning loss (#316)
gpucce Dec 22, 2022
dee1ea5
add `generate` to coca model (#314)
gpucce Dec 22, 2022
30a73d4
use `TextEncoder` in coca `encode_image` (#321)
gpucce Jan 6, 2023
f616050
Merge branch 'main' into coca
rom1504 Jan 6, 2023
061482b
Get some basic PEP changes out of the way
rwightman Jan 9, 2023
d0bd09e
Add tests bis (#355)
gpucce Jan 21, 2023
ef80b7b
Merge branch 'main' into coca
rom1504 Jan 21, 2023
2ab47b7
train.py: fix is_clip when doing distributed (#364)
iejMac Jan 21, 2023
c0e5950
add README (#365)
iejMac Jan 22, 2023
9ab881e
Merge branch 'main' into coca
rom1504 Jan 22, 2023
3f5b0fb
remove output_dict argument (#368)
gpucce Jan 22, 2023
de343fb
do same thing for _encode_image (#366)
iejMac Jan 22, 2023
88aa6ce
CoCa/forward: remove unused output_dict param
iejMac Jan 23, 2023
3b66f37
Revert "do same thing for _encode_image (#366)"
gpucce Jan 24, 2023
cdb91dd
refactor
gpucce Jan 24, 2023
58eb5bd
white space
gpucce Jan 24, 2023
cbd66ed
remove extra layer norm
gpucce Jan 24, 2023
bf6ef3e
move to_logits into decoder
gpucce Jan 24, 2023
03dfeab
leave for later
gpucce Jan 24, 2023
15d6223
better torchscript
gpucce Jan 23, 2023
9beb0d4
annotate hf too
gpucce Jan 23, 2023
fde2aee
Add CoCa-ViT-L/14 config (#379)
iejMac Jan 27, 2023
24e454d
Merge branch 'main' into coca
rom1504 Jan 27, 2023
f7c566b
Remove dead LN code, refactor attn_pool conditional for more clarity,…
rwightman Jan 28, 2023
9533575
latent_dim to embed_dim
gpucce Jan 28, 2023
f5e0c5a
remove extra cfg
gpucce Jan 28, 2023
1ba2ab6
A bit more cleanup, keep context_length as context len, 'num_pos' to …
rwightman Jan 28, 2023
f0847fa
CoCa: add B/32 pretrained (#389)
iejMac Jan 29, 2023
ba081d3
remove coca from ci.yml
rom1504 Jan 29, 2023
8 changes: 5 additions & 3 deletions .github/workflows/ci.yml
@@ -3,7 +3,8 @@ name: Continuous integration
on:
push:
branches:
- main
- main
- coca
paths-ignore:
- '**.md'
- 'CITATION.cff'
@@ -12,7 +13,8 @@ on:
- 'docs/**'
pull_request:
branches:
- main
- main
- coca
paths-ignore:
- '**.md'
- 'CITATION.cff'
@@ -81,7 +83,7 @@ jobs:
--group ${{ matrix.job }} \
-m regression_test \
tests \
| head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=])' \
| head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=-False]|-True])' \
> models_gh_runner.txt
if [ -n "${{ inputs.manual_revision_reference }}" ]; then
REVISION_REFERENCE=${{ inputs.manual_revision_reference }}
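The updated `grep` pattern reflects that the regression-test IDs now end with a trailing True/False parameter, so the model name has to be extracted from IDs such as `test_inference_with_data[RN50-False]`. A rough Python equivalent of the new pattern (PCRE's `\K` swapped for a capture group; the sample IDs are illustrative, not taken from a real CI run):

```python
import re

# Approximation of the grep -Po pattern used in ci.yml above.
pattern = re.compile(r"test_inference_with_data\[([^\]]*)-(?:False|True)\]")

lines = [
    "tests/test_inference.py::test_inference_with_data[RN50-False]",
    "tests/test_inference.py::test_inference_with_data[ViT-B-32-True]",
]

# The greedy [^\]]* backtracks past the final "-True"/"-False", so hyphenated
# model names such as ViT-B-32 are still captured whole.
models = [m.group(1) for line in lines if (m := pattern.search(line))]
print(models)  # ['RN50', 'ViT-B-32']
```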
14 changes: 14 additions & 0 deletions README.md
@@ -258,6 +258,20 @@ python -m training.main \
--resume /path/to/checkpoints/epoch_K.pt
```

### Training CoCa:
Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled by specifying a CoCa config via the ```--model``` parameter of the training script. The currently available configs are "coca_base", "coca_ViT-B-32", and "coca_roberta-ViT-B-32" (which uses RoBERTa as the text encoder). CoCa configs differ from CLIP configs in that they have an additional "multimodal_cfg" component, which specifies parameters for the multimodal text decoder. Here's an example from the coca_ViT-B-32 config:
```json
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"latent_dim": 512,
"attn_pooler_heads": 8
}
```
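For orientation, a minimal sketch of building a CoCa model from this config with the standard open_clip factory helpers (randomly initialized here; pass `pretrained=` to load weights, and treat the tensor shapes as illustrative):

```python
import torch
import open_clip

# The coca_ViT-B-32 config's "multimodal_cfg" block (above) configures the
# multimodal text decoder of the resulting CoCa model.
model, _, _ = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")
model.eval()

with torch.no_grad():
    images = torch.randn(1, 3, 224, 224)      # stand-in for a preprocessed image batch
    texts = tokenizer(["a photo of a cat"])   # (1, 77) token ids
    out = model(images, texts)

# CoCa's forward returns a dict: image_features, text_features, logits,
# labels, and logit_scale (see src/open_clip/coca_model.py below).
print(out["logits"].shape)
```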

### Training with pre-trained language models as text encoder:

If you wish to use a different language model as the text encoder for CLIP, you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing its name and tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters, respectively. Currently we only support RoBERTa (the "test-roberta" config); however, adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with the RoBERTa LM that has its last 10 layers unfrozen:
5 changes: 3 additions & 2 deletions src/open_clip/__init__.py
@@ -1,9 +1,10 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .loss import ClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .coca_model import CoCa
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
193 changes: 193 additions & 0 deletions src/open_clip/coca_model.py
@@ -0,0 +1,193 @@
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p


@dataclass
class MultimodalCfg(CLIPTextCfg):
mlp_ratio: int = 4
dim_head: int = 64
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8


def _build_text_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
Review comment (Collaborator): As per the Encoder/Decoder Module above: with those split, this can be split to have text(_encoder) + text_decoder.

multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)

decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)

return decoder


class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
Review comment (Collaborator): I just noticed this: embed_dim isn't used for CoCa, as it's taken from multimodal_cfg.latent_dim. It's a little weird to have the value in the cfg and in the arg and then not use it.

Review comment (Collaborator): It doesn't look like the MultimodalTransformer tower uses latent_dim itself, so should that just be determined by cfg['embed_dim'] like the other models, with multimodal_cfg['latent_dim'] removed?

Review comment (Collaborator, author): Is it resolved? If not, can you create an issue for it?

text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

self.text = _build_text_tower(
embed_dim=embed_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

vocab_size = (
text_cfg.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else text_cfg.vocab_size
)

self.visual = _build_vision_tower(
embed_dim=embed_dim,
vision_cfg=vision_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

self.text_decoder = _build_text_decoder_tower(
vocab_size,
multimodal_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = pad_id

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
self.text_decoder.set_grad_checkpointing(enable)

def _encode_image(self, images, normalize=True):
image_latent, tokens_embs = self.visual(images)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs

def _encode_text(self, text, normalize=True):
text = text[:, :-1] # make space for CLS token
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb

def encode_image(self, images, normalize=True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent

def encode_text(self, text, normalize=True):
text_latent, _ = self._encode_text(text, normalize=normalize)
return text_latent

def forward(self, image, text):
text_latent, token_embs = self._encode_text(text)
image_latent, image_embs = self._encode_image(image)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]

logits = self.text_decoder(image_embs, token_embs)
return {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"labels": labels,
"logit_scale": self.logit_scale.exp()
}

def generate(
self,
image,
text,
seq_len,
max_seq_len=77,
mask_prob=0.0,
temperature=1.,
filter_logits_fn=top_k,
filter_thres=0.9,
min_p_pow=2.0,
min_p_ratio=0.02,
):

assert mask_prob < 1, "mask_prob must be smaller than 1."

was_training = self.training
num_dims = len(text.shape)

if num_dims == 1:
text = text[None, :]

_, t = text.shape
self.eval()
out = text

for _ in range(seq_len):
x = out[:, -max_seq_len:]

# TODO: adjust for dict output
logits = self(image, x)["logits"][:, -1]

if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)

elif filter_logits_fn is top_a:
filtered_logits = filter_logits_fn(
logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio
)
probs = F.softmax(filtered_logits / temperature, dim=-1)

sample = torch.multinomial(probs, 1)

out = torch.cat((out, sample), dim=-1)

out = out[:, t:]

if num_dims == 1:
out = out.squeeze(0)

self.train(was_training)
return out
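A minimal sketch of calling `generate` on a CoCa model as defined here (the model, prompt, and `seq_len` are illustrative; the default `top_k` filter is used, and only the newly sampled tokens are returned):

```python
import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")
model.eval()

images = torch.randn(1, 3, 224, 224)   # stand-in for a preprocessed image batch
prompt = tokenizer(["a photo of"])     # (1, 77) token ids used to seed the decoder

with torch.no_grad():
    # Sample 10 new tokens; generate() strips the prompt, so the result
    # contains only the generated continuation.
    new_tokens = model.generate(images, prompt, seq_len=10)

print(new_tokens.shape)  # (1, 10)
```

With a randomly initialized model the sampled tokens are of course meaningless; the point is only the calling convention.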
29 changes: 28 additions & 1 deletion src/open_clip/factory.py
@@ -12,6 +12,8 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
@@ -177,7 +179,10 @@ def create_model(
if custom_text:
if is_hf_model:
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
if "coca" in model_name:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

@@ -216,6 +221,28 @@ def create_model(
return model


def create_loss(args):
if "coca" in args.model.lower():
return CoCaLoss(
caption_loss_weight=args.coca_caption_loss_weight,
clip_loss_weight=args.coca_contrastive_loss_weight,
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)


def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
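A sketch of how `create_loss` might slot into a training step. The `args` fields mirror the attributes referenced above and their values are illustrative; the unpacking assumes the loss's forward arguments line up with the keys of the CoCa forward output and that CoCaLoss returns its contrastive and captioning terms separately:

```python
from types import SimpleNamespace

import torch
import open_clip
from open_clip import create_loss

# Hypothetical stand-in for the argparse namespace used by the training script.
args = SimpleNamespace(
    model="coca_ViT-B-32",
    coca_caption_loss_weight=2.0,       # illustrative weights
    coca_contrastive_loss_weight=1.0,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    horovod=False,
)
loss_fn = create_loss(args)             # CoCaLoss, since "coca" is in args.model

model, _, _ = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")

out = model(torch.randn(2, 3, 224, 224), tokenizer(["a dog", "a cat"]))

# out holds image_features, text_features, logits, labels, logit_scale,
# which are passed through to the loss as keyword arguments.
contrastive_loss, caption_loss = loss_fn(**out)
total_loss = contrastive_loss + caption_loss
```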
37 changes: 37 additions & 0 deletions src/open_clip/generation_utils.py
@@ -0,0 +1,37 @@
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F


def exists(val):
return val is not None


def top_p(logits, thres=0.9):
# nucleus
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

sorted_indices_to_remove = cum_probs > (1 - thres)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0

sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)


def top_k(logits, thres=0.9):
k = ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs


def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits
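A small sketch of the three filters on a toy six-token vocabulary. Note the `thres` convention here: for `top_k` it is the fraction of the vocabulary to discard, and for `top_p` the kept nucleus covers roughly `1 - thres` of the probability mass, so a higher `thres` means more aggressive filtering in both cases (values below are illustrative):

```python
import torch
import torch.nn.functional as F

from open_clip.generation_utils import top_k, top_p, top_a

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -3.0]])

topk_logits = top_k(logits, thres=0.5)   # keeps ceil(0.5 * 6) = 3 highest logits
topp_logits = top_p(logits, thres=0.9)   # keeps the smallest set covering ~10% of the mass (at least one token)
topa_logits = top_a(logits.clone())      # note: top_a modifies its input in place and flattens kept logits to 1

# Same sampling step as in CoCa.generate (temperature 1.0).
probs = F.softmax(topk_logits / 1.0, dim=-1)
sample = torch.multinomial(probs, 1)     # id of the sampled token
```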