In [None]:
from   io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
from   PIL import Image
import requests
from   sentence_transformers import SentenceTransformer
from   sentence_transformers.util import cos_sim
import torch
from   transformers import (
    CLIPModel, CLIPProcessor,
    DPRContextEncoder, DPRContextEncoderTokenizer,
    DPRQuestionEncoder, DPRQuestionEncoderTokenizer)

In [None]:
# Sentence-embedding model for the semantic-similarity demo that follows.
mod = SentenceTransformer(model_name_or_path='all-mpnet-base-v2')

In [2]:
# Toy corpus for semantic search. The final sentence acts as the query: it
# paraphrases the bee-mutiny sentence with no content words in common.
sentences = [
    'it caught him off guard that space smelled of seared steak',
    'she could not decide between painting her teeth or brushing her nails',
    "he thought there'd be sufficient time if he hid his watch",
    'the bees decided to have a mutiny against their queen',
    'the sign said there was road work ahead so she decided to speed up',
    "on a scale of one to ten what's your favorite flavor of color?",
    'flying stinging insects rebelled in opposition to the matriarch',
]

In [None]:
# Embed every sentence; the result has one row per input sentence.
embeddings = mod.encode(sentences=sentences)
embeddings.shape

In [None]:
# Score the query (last sentence) against every other sentence.
scores = cos_sim(a=embeddings[-1], b=embeddings[:-1])
scores

In [None]:
# `scores` was computed against embeddings[:-1], whose positions line up with
# sentences[:-1], so the argmax indexes `sentences` directly.
sentences[int(scores.argmax())]

### Question Answering (dense passage retrieval with DPR)

In [None]:
# DPR uses two separate encoders: one for passages (contexts) and one for
# questions, each paired with its own tokenizer.
pretrained = 'facebook/dpr-ctx_encoder-single-nq-base'
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(pretrained)
ctx_mod = DPRContextEncoder.from_pretrained(pretrained)

pretrained = 'facebook/dpr-question_encoder-single-nq-base'
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(pretrained)
question_mod = DPRQuestionEncoder.from_pretrained(pretrained)

In [4]:
# Queries to answer by retrieval.
questions = [
    'what is the capital of australia?',
    'what is the best selling sci-fi book?',
    'how many searches are performed on google?']

# Candidate passages. Each question also appears here verbatim, presumably as
# a distractor: a good retriever should prefer the passage that *answers* the
# question over the one that merely repeats it.
contexts = [
    'canberra is the capital city of australia',
    'what is the capital of australia?',
    'what country is paris the capital of?',
    'sci-fi is a popular genre beloved by millions',
    'the best selling sci-fi book is that atrocity by L. Ron Hubbard',
    'google is a popular search engine',
    'what is the best selling sci-fi book?',
    'how many searches are performed on google?',
    'google serves more than 2 trillion queries per year']

In [None]:
# Embed passages with the context encoder and questions with the question
# encoder. Each side must go through its *own* tokenizer and model.
# `truncation=True` makes the max_length cut explicit (otherwise the tokenizer
# warns); `no_grad` skips autograd bookkeeping since this is inference only.
with torch.no_grad():
    xb_tokens = ctx_tokenizer(
        contexts, max_length=256, padding='max_length', truncation=True,
        return_tensors='pt')
    xb = ctx_mod(**xb_tokens)

    xq_tokens = question_tokenizer(
        questions, max_length=256, padding='max_length', truncation=True,
        return_tensors='pt')
    # Must be xq_tokens: feeding xb_tokens here would embed the contexts a
    # second time, giving len(contexts) rows instead of one per question.
    xq = question_mod(**xq_tokens)

In [None]:
# Inspect which fields the question-encoder output `xq` exposes.
xq.keys()

In [None]:
# Rows: one per question (xq) vs. one per context (xb).
(xq.pooler_output.shape, xb.pooler_output.shape)

In [None]:
# For each question, retrieve the context whose embedding is most similar.
# NOTE(review): DPR was trained with dot-product similarity; cosine is used
# here and may rank passages slightly differently.
for i, xq_vec in enumerate(xq.pooler_output):
    # Cosine similarities (not probabilities): shape (1, len(contexts)).
    sims = cos_sim(xq_vec, xb.pooler_output)
    best = sims.argmax().item()  # plain int for list indexing
    print(questions[i])
    print('  ', contexts[best], '\n')

### Text–Image Matching with CLIP (image → caption and caption → image)

In [None]:
# CLIP embeds captions and images into a shared space so they can be compared
# directly. Note: this rebinds `mod`, replacing the SentenceTransformer loaded
# at the top of the notebook.
pretrained = 'openai/clip-vit-base-patch32'
processor = CLIPProcessor.from_pretrained(pretrained)
mod = CLIPModel.from_pretrained(pretrained)

In [None]:
# Example Unsplash photos (base path, then query-string parameters).
urls = [
    ('https://images.unsplash.com/photo-1576201836106-db1758fd1c97'
     '?ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8'
     '&ixlib=rb-1.2.1&auto=format&fit=crop&w=400&q=80'),
    ('https://images.unsplash.com/photo-1591294100785-81d39c061468'
     '?ixlib=rb-1.2.1'
     '&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8'
     '&auto=format&fit=crop&w=300&q=80'),
    ('https://images.unsplash.com/photo-1548199973-03cce0bbc87b'
     '?ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8'
     '&ixlib=rb-1.2.1&auto=format&fit=crop&w=400&q=80'),
]

In [None]:
# Download the example images. `raise_for_status` surfaces HTTP errors instead
# of handing an error page to PIL, and the timeout keeps a stalled request from
# hanging the kernel. Reading `.content` (instead of the raw socket stream)
# lets requests undo any content encoding such as gzip first.
REQUEST_TIMEOUT = 30  # seconds

images = []
for url in urls:
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    images.append(Image.open(BytesIO(response.content)))

# Show each image in its own figure. `plt.show()` does not take an artist
# argument, so draw first and then show.
for img in images:
    plt.imshow(np.asarray(img))
    plt.axis('off')
    plt.show()

In [None]:
# Candidate captions; CLIP scores every caption against every image.
captions = [
    'a dog hiding behind a tree',
    'two dogs running',
    'a dog running',
    'a cucumber on a tree',
    'trees in the park',
    'a dog eating a cucumber',
]

In [None]:
# Tokenize the captions and preprocess the images in a single call; padding
# lets captions of different lengths share one tensor.
inputs = processor(
    text=captions,
    images=images,
    return_tensors='pt',
    padding=True,
)

In [None]:
# Run CLIP on the caption/image batch. Inference only, so disable autograd
# tracking to save memory.
with torch.no_grad():
    outputs = mod(**inputs)
outputs.keys()

In [None]:
# logits_per_image has one row per image and one column per caption, so the
# argmax over dim=1 picks the best caption for each image. These are indices
# into `captions`, not probabilities.
best_caption_idx = outputs.logits_per_image.argmax(dim=1)

for img, caption_idx in zip(images, best_caption_idx.tolist()):
    print(captions[caption_idx])
    # Draw, then show; `plt.show()` does not take an artist argument.
    plt.imshow(np.asarray(img))
    plt.axis('off')
    plt.show()

In [None]:
# text_embeds: one row per caption; image_embeds: one row per image.
(outputs.text_embeds.shape, outputs.image_embeds.shape)

In [None]:
# Text -> image search: use the first caption ('a dog hiding behind a tree')
# as the query and score it against every image embedding.
xq, xb = outputs.text_embeds[0], outputs.image_embeds
sim = cos_sim(xq, xb)
sim

In [None]:
# Position (in `images`) of the image that best matches the caption query.
pred = int(sim.argmax())
pred