In [8]:
import torch 
import torch.nn as nn
from typing import List, Union, Tuple, Dict, Optional

import numpy as np

from models import FusionModule

from PIL import Image
import requests
from transformers import (
    VisionTextDualEncoderModel,
    VisionTextDualEncoderProcessor,
    AutoImageProcessor,
    AutoTokenizer,
)
from transformers import VisionTextDualEncoderConfig, AutoConfig
from transformers.models.clip.modeling_clip import CLIPOutput

from models import FusionModule

In [2]:
# Hub checkpoints used as the source of the tokenizer, the image processor
# and the architecture hyper-parameters of the two towers.
vision_pretrained = "google/vit-base-patch16-224"
text_pretraind = "bert-base-uncased"

# Per-tower architecture settings, flattened to plain dicts because that is
# the form handed to VisionTextDualEncoderConfig below.
vision_config = AutoConfig.from_pretrained(vision_pretrained).to_dict()
text_config = AutoConfig.from_pretrained(text_pretraind).to_dict()

# One processor object that handles both modalities: images go through the
# image processor, captions through the tokenizer.
image_processor = AutoImageProcessor.from_pretrained(vision_pretrained)
tokenizer = AutoTokenizer.from_pretrained(text_pretraind)
processor = VisionTextDualEncoderProcessor(
    image_processor=image_processor,
    tokenizer=tokenizer,
)

# Joint config for the dual encoder, with a 512-dimensional projection space.
config = VisionTextDualEncoderConfig(
    projection_dim=512,
    vision_config=vision_config,
    text_config=text_config,
)

# NOTE(review): the model is built from `config` only, so no checkpoint
# weights are loaded here -- both towers presumably start from random
# initialisation despite the `*_pretrained` variable names. Confirm that this
# is intended (otherwise see `from_vision_text_pretrained`).
dual_encoder = VisionTextDualEncoderModel(config)
fusioner = FusionModule()

In [3]:
# Smoke test: run two image/caption pairs through the dual encoder and feed
# the resulting hidden states to the fusion module. This is a forward pass
# only -- no loss is computed and no weights are updated in this cell.
urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg",
]


def _load_image(url: str, timeout: float = 30.0) -> Image.Image:
    """Download `url` and return it as a fully decoded RGB PIL image.

    Args:
        url: HTTP(S) address of the image file.
        timeout: Seconds to wait for the server before giving up, so a
            stalled host cannot hang the kernel indefinitely.

    Returns:
        The image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
    """
    # The context manager releases the streamed connection when done.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # HTML error page.
        response.raise_for_status()
        # Ask urllib3 to undo any Content-Encoding (gzip/deflate) on the
        # raw stream before PIL reads from it.
        response.raw.decode_content = True
        # convert() forces the pixel data to be read while the connection is
        # still open and normalises grayscale/palette images to 3 channels.
        return Image.open(response.raw).convert("RGB")


images = [_load_image(url) for url in urls]
inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=images,
    return_tensors="pt",
    padding=True,
)

# Forward pass through both towers.
# NOTE(review): `dual_encoder` is constructed directly from a config earlier
# in this notebook, so it is presumably still in training mode (dropout
# active) and autograd is recording. If this is meant as pure inference, call
# `.eval()` and wrap the calls in `torch.no_grad()` -- confirm before changing.
encoder_outputs: CLIPOutput = dual_encoder(**inputs)

# These are the towers' `last_hidden_state` tensors, not the output's
# projected `text_embeds` / `image_embeds` fields, despite the local names.
text_embeds = encoder_outputs.text_model_output.last_hidden_state
image_embeds = encoder_outputs.vision_model_output.last_hidden_state

outputs = fusioner(text_embeds, image_embeds)