# CLIP utilities

In [None]:
#|default_exp ml.clip

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export

import clip
import torch
import torch.nn as nn
from torchvision import transforms
from bellem.ml.vision import TorchVisionTransform

In [None]:
#|export

def load_clip_preprocess(clip_model_name):
    """Return the preprocessing object that `clip.load` provides for a model.

    `clip.load` is called for `clip_model_name` with ``device='cpu'``; only the
    second element of its result is returned, the first one is dropped.
    """
    from clip import clip as clip_module

    loaded = clip_module.load(clip_model_name, device='cpu')
    return loaded[1]


In [None]:
#|export

def make_tfms_from_clip_preprocess(clip_preprocess):
    """Split a composed preprocessing pipeline into item and batch transforms.

    Every step of `clip_preprocess.transforms` except the last two goes into
    the item-level transform; the last two steps go into the batch-level
    transform. Each group is re-composed and wrapped in `TorchVisionTransform`.

    Returns the pair ``(item_tfms, batch_tfms)``.
    """
    # NOTE(review): presumably the trailing two steps are tensor conversion and
    # normalisation, as in the stock CLIP pipeline -- confirm for custom ones.
    steps = clip_preprocess.transforms
    tail_len = 2

    per_item = TorchVisionTransform(transforms.Compose(steps[:-tail_len]))
    per_batch = TorchVisionTransform(transforms.Compose(steps[-tail_len:]))
    return per_item, per_batch

In [None]:
#|export

class ClipClassificationHead(nn.Module):
    """Scaled cosine-similarity head over image and text feature vectors.

    The head owns a trainable `logit_scale` parameter initialised from a
    detached copy of `clip_model.logit_scale`.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Copy the value so this head's scale is independent of the source model.
        initial_scale = clip_model.logit_scale.detach().clone()
        self.logit_scale = nn.Parameter(initial_scale, requires_grad=True)

    def forward(self, image_features, text_features):
        """Return ``exp(logit_scale)`` times the cosine similarity of each pair.

        Both inputs are L2-normalised along their last dimension before the
        image features are multiplied with the transposed text features.
        """
        image_unit = image_features / image_features.norm(dim=-1, keepdim=True)
        text_unit = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = image_unit @ text_unit.t()
        return self.logit_scale.exp() * similarity

In [None]:
#|export

class ClipZeroShotClassifier(nn.Module):
    """Zero-shot image classifier built from a CLIP model and class descriptions.

    The text features of `class_descriptions` are computed once at construction
    time and stored as a frozen parameter (`class_text_features`). `forward`
    encodes images with `clip_model` and scores them against those stored
    features through a `ClipClassificationHead`.
    """

    def __init__(self, clip_model, class_descriptions):
        super().__init__()
        self.clip_model = clip_model
        self.head = ClipClassificationHead(clip_model)
        # Use no_grad rather than inference_mode: tensors created under
        # inference_mode cannot be updated in place outside of it, which breaks
        # in-place parameter updates such as `load_state_dict` on this module.
        with torch.no_grad():
            ctf = self.compute_text_features(class_descriptions)
        self.class_text_features = nn.Parameter(ctf, requires_grad=False)

    def forward(self, image):
        """Return the head's logits for `image` against every stored class."""
        image_features = self.clip_model.encode_image(image)
        # The stored text features were cast with `.float()`; align the image
        # features to the same dtype so the matmul in the head does not fail
        # when the encoder runs in reduced precision. No-op when they match.
        image_features = image_features.to(self.class_text_features.dtype)
        return self.head(image_features, self.class_text_features)

    def compute_text_features(self, texts):
        """Tokenize `texts` and encode them with the CLIP text encoder.

        Tokens are moved to the device of the model's first parameter and the
        encoded features are returned as float32.
        """
        device = next(self.clip_model.parameters()).device
        text_tokens = clip.tokenize(texts)
        text_features = self.clip_model.encode_text(text_tokens.to(device)).float()
        return text_features


In [None]:
#|hide
# Run nbdev's export step for this notebook project (hidden cell, not exported).
import nbdev; nbdev.nbdev_export()