In [1]:
!pip install -U open_clip_torch



In [2]:
import open_clip
import pprint
pprint.pp([p for p in open_clip.list_pretrained()])

[('RN50', 'openai'),
 ('RN50', 'yfcc15m'),
 ('RN50', 'cc12m'),
 ('RN101', 'openai'),
 ('RN101', 'yfcc15m'),
 ('RN50x4', 'openai'),
 ('RN50x16', 'openai'),
 ('RN50x64', 'openai'),
 ('ViT-B-32', 'openai'),
 ('ViT-B-32', 'laion400m_e31'),
 ('ViT-B-32', 'laion400m_e32'),
 ('ViT-B-32', 'laion2b_e16'),
 ('ViT-B-32', 'laion2b_s34b_b79k'),
 ('ViT-B-32', 'datacomp_xl_s13b_b90k'),
 ('ViT-B-32', 'datacomp_m_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_clip_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_laion_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_image_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_text_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_basic_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_s128m_b4k'),
 ('ViT-B-32', 'datacomp_s_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_clip_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_laion_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_image_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_text_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_basic_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_s13m_b4k'),
 ('ViT-

In [3]:
import torch
import torch.nn as nn
import open_clip
import torch.nn.functional as F

In [4]:
MODEL_NAME = "MobileCLIP2-B"
PRETRAINED = "dfndr2b"

base_model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
print(base_model)

device = "cuda" if torch.cuda.is_available() else "cpu"

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): HybridEmbed(
        (backbone): ConvStem(
          (0): ConvNormAct(
            (conv): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4), bias=False)
            (bn): BatchNormAct2d(
              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
              (drop): Identity()
              (act): GELU(approximate='none')
            )
          )
          (1): ConvNormAct(
            (conv): Conv2d(192, 192, kernel_size=(2, 2), stride=(2, 2), bias=False)
            (bn): BatchNormAct2d(
              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
              (drop): Identity()
              (act): GELU(approximate='none')
            )
          )
          (2): ConvNormAct(
            (conv): Conv2d(192, 768, kernel_size=(2, 2), stride=(2, 2))
            (bn): Sequential()
          )
        )
        (proj): Identity()
      )
      (pos_dr

In [5]:
class CoOpPromptClassifier(nn.Module):
    def __init__(self, base_model, tokenizer, classnames, n_ctx=16, device="cpu"):
        super().__init__()
        self.base_model = base_model
        self.tokenizer = tokenizer
        self.classnames = classnames
        self.n_ctx = n_ctx
        self.device = device

        # Make it work for any CLIP-like model
        if hasattr(self.base_model, 'text'):
            self.text_tower = self.base_model.text
        else:
            self.text_tower = self.base_model
        if hasattr(self.base_model, 'visual'):
            self.visual_tower = self.base_model.visual
        else:
            self.visual_tower = self.base_model

        self.embed_size = self.text_tower.token_embedding.shape[1]
        self.ctx = nn.Parameter(torch.randn(n_ctx, self.embed_size))

        with torch.no_grad():
            self.class_token_ids = self.tokenizer(classnames).to(self.device)
            self.class_token_embeddings = self.text_tower.token_embedding(self.class_token_ids)
        
        for p in self.model.parameters():
            p.requires_grad = False

    
    def forward_text_features(self):
        torch.cat([self.ctx, self.base_model.text.token_embedding(self.class_token_ids)], dim=1)

In [None]:
class CoOpPromptRetrieval(nn.Module):
    def __init__(self, base_model, tokenizer, n_ctx=16, device="cpu"):
        super().__init__()
        self.base_model = base_model.to(device)
        self.tokenizer = tokenizer
        self.n_ctx = n_ctx
        self.device = device

        # Make it work for any CLIP-like model
        if hasattr(self.base_model, 'text'):
            self.text_tower = self.base_model.text
        else:
            self.text_tower = self.base_model
        if hasattr(self.base_model, 'visual'):
            self.visual_tower = self.base_model.visual
        else:
            self.visual_tower = self.base_model

        self.embedding_size = self.text_tower.token_embedding.weight.shape[1]
        self.ctx = nn.Parameter(0.02 * torch.randn(n_ctx, self.embedding_size))

        for p in self.base_model.parameters():
            p.requires_grad = False
    
    def embeddings(self, text):
        return self.text_tower.token_embedding(self.tokenizer(text).to(self.device))

    
    def forward(self, image, text):
        # Tokenize if needed
        if isinstance(text, list) or (isinstance(text, torch.Tensor) and text.dtype == torch.object):
            tokens = self.tokenizer(text).to(self.device)          # [B, L]
        else:
            tokens = text.to(self.device)                          # already token ids: [B, L]

        B, L = tokens.shape

        # Text: inject learned context after BOS (position 0)
        tok = self.text_tower.token_embedding(tokens)              # [B, L, D]
        x = tok
        n = min(self.n_ctx, L - 2)                                 # keep BOS/EOS
        if n > 0:
            x = x.clone()
            x[:, 1:1 + n, :] = self.ctx[:n].unsqueeze(0).expand(B, n, -1)

        x = x + self.text_tower.positional_embedding[:L].unsqueeze(0)
        x = x.permute(1, 0, 2)                                     # [L, B, D]
        x = self.text_tower.transformer(x)
        x = x.permute(1, 0, 2)                                     # [B, L, D]
        x = self.text_tower.ln_final(x)

        eos_idx = tokens.argmax(dim=-1)                            # EOS position
        text_features = x[torch.arange(B, device=self.device), eos_idx]
        text_features = text_features @ self.text_tower.text_projection

        # Image: standard encode
        image = image.to(self.device)
        image_features = self.base_model.encode_image(image)

        return image_features, text_features

    
    def logits(self, image, text):
        image_features, text_features = self.forward(image, text)
        logit_scale = self.base_model.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.T
        return logits

In [7]:
coop_retrieval_model = CoOpPromptRetrieval(base_model, tokenizer, n_ctx=16, device=device)
optimizer = torch.optim.AdamW([coop_retrieval_model.ctx], lr=1e-3)

In [8]:
from datasets import load_dataset
train_ds = load_dataset("AnyModal/flickr30k", split="train")
val_ds = load_dataset("AnyModal/flickr30k", split="validation")

train_ds = train_ds.with_transform(
    lambda ex: {
        "image":   [preprocess_train(img) for img in ex["image"]],
        "alt_text":[caps[0] for caps in ex["alt_text"]],
    }
)
val_ds = val_ds.with_transform(
    lambda ex: {
        "image":   [preprocess_val(img) for img in ex["image"]],
        "alt_text":[caps[0] for caps in ex["alt_text"]],
    }
)


In [9]:
from helper_functions import get_data, fit_retrieval, clip_contrastive_loss_from_logits

BATCH_SIZE = 32
EPOCHS = 5

train_dl, val_dl = get_data(train_ds, val_ds, bs=BATCH_SIZE, n_shot=8)
fit_retrieval(EPOCHS, coop_retrieval_model, clip_contrastive_loss_from_logits, optimizer, train_dl, val_dl)

Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]

                                                      

Epoch 0, val_loss 173.32937399358204


                                                      

Epoch 1, val_loss 173.87181479549972


                                                      

Epoch 2, val_loss 174.71024691188595


                                                      

Epoch 3, val_loss 175.717343422553


                                                      

Epoch 4, val_loss 176.7060768082297
