In [31]:
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

In [30]:
def exists(val) -> bool:
    """Return True when ``val`` is not None."""
    return val is not None

class BaseModel(nn.Module):
    """Holder for optional image/text encoders and an image preprocessor.

    Args:
        image_encoder: model called as ``image_encoder(**processor_outputs)``;
            its output must expose ``last_hidden_state`` and ``pooler_output``
            (e.g. a Hugging Face ``CLIPVisionModel``).
        text_encoder: stored on the instance but not used by this class.
        image_processor: callable invoked as
            ``image_processor(images=..., return_tensors='pt')`` that returns a
            mapping of keyword inputs for ``image_encoder`` (e.g. a Hugging Face
            ``CLIPProcessor``, which is not an ``nn.Module``).
    """

    def __init__(
        self,
        image_encoder: Optional[nn.Module] = None,
        text_encoder: Optional[nn.Module] = None,
        image_processor: Optional[Callable] = None,
    ):
        super().__init__()

        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_processor = image_processor

    def embed_image(
        self,
        # String annotation: PIL is imported in a later cell, and an eager
        # reference here would raise NameError on a fresh kernel.
        images: Optional["Image.Image"] = None,
        image_tokens: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Encode images with the encoder/processor passed at init.

        Exactly one of ``images`` or ``image_tokens`` must be given.

        Args:
            images: raw image(s), forwarded to
                ``image_processor(images=..., return_tensors='pt')``.
            image_tokens: precomputed encoder output; returned unchanged and the
                encoder is skipped.

        Returns:
            ``(image_embeds, image_tokens)``. When ``images`` is given these are
            the encoder's ``last_hidden_state`` and ``pooler_output``. When
            ``image_tokens`` is given, ``image_embeds`` is ``None`` because there
            is no encoder output to take it from.

        Raises:
            AssertionError: if both or neither inputs are given, or if ``images``
                is given without an ``image_encoder`` and ``image_processor``.
        """
        assert not (exists(images) and exists(image_tokens)), "Can only pass one of images or image_tokens"
        assert exists(images) or exists(image_tokens), "Must pass one of images or image_tokens"

        # Bound up front so the precomputed-tokens path can return it.
        image_embeds: Optional[torch.Tensor] = None

        if exists(images):
            assert exists(self.image_encoder), "Must pass image_encoder at init for image encoding"
            assert exists(self.image_processor), "Must pass image_processor at init for image encoding"
            image_inputs = self.image_processor(images=images, return_tensors='pt')
            outputs = self.image_encoder(**image_inputs)
            # NOTE(review): in this notebook's CLIP run, last_hidden_state is
            # (1, 50, 768) per-token output and pooler_output is (1, 768) pooled
            # output, so the names below may be swapped. TODO confirm the
            # intended convention before callers depend on it.
            image_embeds, image_tokens = outputs.last_hidden_state, outputs.pooler_output

        return image_embeds, image_tokens
    

In [18]:
from transformers import CLIPVisionModel, CLIPProcessor
from PIL import Image
import requests

CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

image_encoder = CLIPVisionModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# The timeout keeps a stalled download from hanging the kernel, and
# raise_for_status surfaces HTTP errors instead of a confusing PIL decode error.
# The context manager closes the connection once the image is read.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    image.load()  # Image.open is lazy; read the pixels before the stream closes

Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.layer_norm2.bias', 'text_model.final_layer_norm.bias', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.10.self_attn.q_proj.we

In [None]:
# Wire in the CLIP vision encoder and processor loaded above; without them
# embed_image has nothing to encode with.
model = BaseModel(
    image_encoder=image_encoder,
    image_processor=processor,
)

In [19]:
inputs = processor(images=image, return_tensors='pt')
# Inspection-only forward pass: skip autograd bookkeeping to save memory.
with torch.no_grad():
    outputs = image_encoder(**inputs)

In [29]:
# Shapes of the encoder's last_hidden_state and pooler_output, in that order.
tuple(t.shape for t in (outputs.last_hidden_state, outputs.pooler_output))

(torch.Size([1, 50, 768]), torch.Size([1, 768]))