In [1]:
from transformers import VisionTextDualEncoderModel,VisionTextDualEncoderProcessor,ViTFeatureExtractor,BertTokenizer
from PIL import Image
import torch
from pathlib import Path
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader
import gc
import matplotlib.pyplot as plt

In [2]:
%%capture
# Run on the GPU when torch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Locally saved artefacts are used only when all three directories are present;
# otherwise the stock pretrained ViT and BERT checkpoints are used.
local_dirs = ('duo_constrated_tokenizer', 'duo_constrated_model', 'duo_constrated_feature_extractor')
have_local = all(Path(directory).exists() for directory in local_dirs)

if have_local:
    tokenizer_source = 'duo_constrated_tokenizer'
    extractor_source = 'duo_constrated_feature_extractor'
else:
    tokenizer_source = "bert-base-uncased"
    extractor_source = "google/vit-base-patch16-224"

tokenizer = BertTokenizer.from_pretrained(tokenizer_source)
feature_extractor = ViTFeatureExtractor.from_pretrained(extractor_source)
processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer)

if have_local:
    model = VisionTextDualEncoderModel.from_pretrained('duo_constrated_model')
else:
    # No saved dual encoder: assemble one from the separate vision and text checkpoints.
    model = VisionTextDualEncoderModel.from_vision_text_pretrained("google/vit-base-patch16-224", "bert-base-uncased")

# Move the weights to the chosen device and switch to inference mode.
model.to(device).eval()

In [3]:
# Load the spreadsheet that lists the images to search. Later cells read its
# 'path to image' and 'image caption' columns.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
training_data = pd.read_excel(r'/media/delta/S/trainig_data.xlsx')

In [4]:
class ImageData(Dataset):
    """Dataset that eagerly preprocesses every image listed in a table.

    At construction time each image is opened, converted to RGB, passed through
    ``feature_extractor`` and moved to the notebook-level ``device``; nothing is
    read from disk in ``__getitem__``. Note that ``device`` is a global that must
    be defined before this class is instantiated.

    Args:
        ds: Table (e.g. a pandas DataFrame) with a ``'path to image'`` column and
            an ``'image caption'`` column.
        feature_extractor: Callable invoked as
            ``feature_extractor(image, return_tensors='pt')`` whose result
            exposes a ``pixel_values`` tensor.

    Each item is a ``(pixel_values, image_path, caption)`` triple.
    """

    def __init__(self,ds, feature_extractor):
        self.images = ds['path to image'].values.tolist()
        self.captions = ds['image caption'].values.tolist()
        self.pixel_values = []
        for path in self.images:
            # Open via a context manager so the file handle is released as soon
            # as the RGB copy exists, instead of whenever the garbage collector
            # gets to the source image.
            with Image.open(path) as source_image:
                rgb_image = source_image.convert('RGB')
            tensor = feature_extractor(rgb_image, return_tensors='pt').pixel_values.to(device)
            self.pixel_values.append(tensor)

    def __len__(self):
        """Return the number of preprocessed images."""
        return len(self.pixel_values)

    def __getitem__(self, idx):
        """Return ``(pixel_values, image_path, caption)`` for position ``idx``."""
        return  self.pixel_values[idx], self.images[idx], self.captions[idx]

In [5]:
def collate_fn(examples):
    """Merge dataset items into one batch dictionary.

    Args:
        examples: Sequence of indexable items whose element 0 is a tensor,
            element 1 an image path and element 2 a caption.

    Returns:
        dict with ``'pixel_values'`` (the tensors stacked along a new leading
        axis), ``'image_path'`` (list of paths) and ``'captions'`` (list of
        captions), all in input order.
    """
    tensors, paths, texts = [], [], []
    for sample in examples:
        tensors.append(sample[0])
        paths.append(sample[1])
        texts.append(sample[2])
    batch = {"pixel_values": torch.stack(tensors)}
    batch['image_path'] = paths
    batch['captions'] = texts
    return batch

In [6]:
# Batches of 64 preprocessed images, served in spreadsheet order (no shuffling)
# and assembled by collate_fn into a dict of tensors, paths and captions.
asset_data = DataLoader(
    dataset=ImageData(training_data, feature_extractor),
    batch_size=64,
    collate_fn=collate_fn,
    shuffle=False,
)

In [7]:
# Free memory before running inference: collect unreachable Python objects,
# then ask PyTorch to release cached GPU memory it is no longer using.
gc.collect()
torch.cuda.empty_cache()

In [None]:
captions = 'restraint ladder is out of date'   # Free-text query: edit to search with your own caption

# The query does not change between batches, so tokenize it and move it to the
# model's device once instead of once per batch.
inputs = tokenizer(captions,padding='max_length',return_tensors='pt',max_length=128).to(device)

scores = []        # best image-text logit of each batch
image_paths = []   # path of the best-scoring image of each batch
caption = []       # stored caption of the best-scoring image of each batch
for batch in asset_data:
    stacked = batch['pixel_values']
    # Collapse any extra leading axes down to (N, C, H, W) without hard-coding
    # the image size, so a feature extractor with another resolution also works.
    inputs['pixel_values'] = stacked.reshape(-1, *stacked.shape[-3:]).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # One query text per forward pass, so flattening leaves one logit per image.
    logits_per_image = outputs.logits_per_image.view(-1)
    best_in_batch = logits_per_image.argmax().item()
    scores.append(logits_per_image[best_in_batch].item())
    image_paths.append(batch['image_path'][best_in_batch])
    caption.append(batch['captions'][best_in_batch])

if not scores:
    raise ValueError('asset_data yielded no batches; there are no images to search')

# The overall winner is the best of the per-batch winners.
ids = torch.topk(torch.tensor(scores),1).indices
best_batch = ids[0].item()
# These names are rebound to the single best match, as the original cell did,
# in case later cells read them.
scores = [scores[best_batch]]
image_paths = image_paths[best_batch]
text = caption[best_batch]
print('actual_caption:',text)
print('requested caption',captions)
Image.open(image_paths)