# [CoCa: Contrastive Captioners are Image-Text Foundation Models](https://arxiv.org/pdf/2205.01917.pdf)

In [None]:
https://github.com/jiaowoguanren0615/CoCa-Pytorch/tree/main/CoCa

In [4]:
# !pip3 install coca-pytorch

In [16]:
import torch
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print('Using device', device)

# import vision transformer
# from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    # patch_dropout = 0.5  # https://arxiv.org/abs/2212.00794
)

vit = Extractor(vit, return_embeddings_only = True, detach = False)
vit
# extractor will enable it so the vision transformer returns its embeddings

Using device cpu


Extractor(
  (vit): ViT(
    (to_patch_embedding): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
      (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
      (2): Linear(in_features=3072, out_features=1024, bias=True)
      (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (dropout): Dropout(p=0.0, inplace=False)
    (transformer): Transformer(
      (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-5): 6 x ModuleList(
          (0): Attention(
            (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (attend): Softmax(dim=-1)
            (dropout): Dropout(p=0.0, inplace=False)
            (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1024, out_features=1024, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          

In [17]:
# import CoCa and instantiate it
from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
    num_tokens = 20000,            # number of text tokens
    unimodal_depth = 6,            # depth of the unimodal transformer
    multimodal_depth = 6,          # depth of the multimodal transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).to(device)
coca

CoCa(
  (token_emb): Embedding(20000, 512)
  (img_encoder): Extractor(
    (vit): ViT(
      (to_patch_embedding): Sequential(
        (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
        (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
        (2): Linear(in_features=3072, out_features=1024, bias=True)
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (dropout): Dropout(p=0.0, inplace=False)
      (transformer): Transformer(
        (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (layers): ModuleList(
          (0-5): 6 x ModuleList(
            (0): Attention(
              (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (attend): Softmax(dim=-1)
              (dropout): Dropout(p=0.0, inplace=False)
              (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
              (to_out): Sequential(
                (0): Linear(in_features=1024, out_f

In [None]:
# mock text and images

text = torch.randint(0, 20000, (4, 512)).to(device)
images = torch.randn(4, 3, 256, 256).to(device)

# train by giving CoCa your text and images with `return_loss = True`

loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

In [12]:
# do the above for as much text and images...
# then you can get the caption logits as so
logits = coca(
    text = text,
    images = images
) # (4, 512, 20000)
logits.shape

torch.Size([4, 512, 20000])

In [14]:
# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = coca(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)
text_embeds.shape, image_embeds.shape

(torch.Size([4, 512]), torch.Size([4, 512]))