# Flickr30k logical composition retrieval example

In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from datasets import load_dataset
from tqdm.auto import tqdm

from subspaces import ridge_projector, join
from model import TransformerSubspaceEmbedder
from train_nli import NLITrainingData

# Cache directory handed to both the dataset loader and the embedder.
CACHE_DIR = "./.cache"

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

### Load SNLI-trained model

In [None]:
# Checkpoint of the SNLI training run: it carries both the hyper-parameters
# needed to rebuild the embedder and the trained weights.
RUN_PATH = "./runs/all-mpnet-base-v2_128x128_context35.pt"
train_data = NLITrainingData.load(RUN_PATH)

# Rebuild the embedder with the run's base model and subspace settings.
model = TransformerSubspaceEmbedder(
    train_data.base_model_name,
    train_data.N,
    train_data.D,
    train_data.lbd,
    two_way=False,
    cache_dir=CACHE_DIR,
)
# strict=False: assuming `model` is a torch.nn.Module, keys missing from or
# unexpected in the checkpoint are ignored instead of raising.
model.load_state_dict(train_data.state_dict, strict=False)
# Inference mode (no dropout etc.), then move the weights to the chosen device.
model.eval()
model.to(device)

### Load Flickr30k

In [None]:
dataset = load_dataset("nlphuji/flickr30k", cache_dir=CACHE_DIR, split="test")

# Flatten the per-image caption lists into a single list of captions, and
# record for each caption the dataset row (image) it belongs to.
# Iterating the "caption" column directly avoids `dataset[i]`, which
# materializes the whole row -- including its image -- just to read captions.
captions, img_indices = [], []
for img_idx, image_captions in enumerate(dataset["caption"]):
    captions.extend(image_captions)
    img_indices.extend([img_idx] * len(image_captions))

### Compute caption subspaces (smooth projectors)

In [None]:
B = 256  # number of captions encoded per forward pass
chunked_captions = [captions[start:start + B] for start in range(0, len(captions), B)]

# Encode chunk by chunk without recording an autograd graph, then stack the
# per-chunk results into one tensor of caption projectors.
# Note: `P` stays bound to the last chunk's output after the loop.
P_captions = []
with torch.no_grad():
    for chunk in tqdm(chunked_captions):
        P = model.encode(chunk, max_length=train_data.max_length, device=device)
        P_captions.append(P)
P_captions = torch.cat(P_captions)

### Retrieval with logical composition of queries

In [None]:
def text2projector(text):
    """Encode a single text into its subspace projector.

    The encoder is given a batch of one text, so the result presumably has a
    leading batch axis of size 1 (TODO confirm: (1, D, D)). Encoding runs under
    ``torch.no_grad()`` so query projectors carry no autograd graph.
    """
    with torch.no_grad():
        return model.encode([text], max_length=train_data.max_length, device=device)


# Identity with a leading batch axis, used to form complements (NOT).
# Built on `device` rather than on a tensor left over from another cell.
I = torch.eye(model.D, device=device)[None]


def negate(P):
    """Logical NOT of a projector: its complement ``I - P``."""
    return I - P


enc = text2projector

# Example compositions: AND = matrix product, NOT = `negate`, OR = `join`.
# Each entry is a zero-argument callable so only the selected query is encoded.
EXAMPLE_QUERIES = {
    "people AND sitting AND on a bench":
        lambda: enc("people") @ enc("sitting") @ enc("on a bench"),
    "food being prepared AND outdoors on the grill":
        lambda: enc("food being prepared") @ enc("outdoors on the grill"),
    "an animal interacting with a human AND in a zoo":
        lambda: enc("an animal interacting with a human") @ enc("in a zoo"),
    "a person walking AND NOT on the sidewalk":
        lambda: enc("a person walking") @ negate(enc("on the sidewalk")),
    "a person walking AND on the sidewalk":
        lambda: enc("a person walking") @ enc("on the sidewalk"),
    "a person riding a bicycle OR a person walking a dog":
        lambda: join(
            enc("a person riding a bicycle")[0],
            enc("a person walking a dog")[0],
            train_data.lbd,
        ),
    "a group of people AND military":
        lambda: enc("a group of people") @ enc("military"),
    "a child playing outdoors AND NOT near a road":
        lambda: enc("a child playing outdoors") @ negate(enc("near a road")),
    "a child playing outdoors AND near a road":
        lambda: enc("a child playing outdoors") @ enc("near a road"),
    "a man AND NOT on a boat AND is fishing":
        lambda: enc("a man") @ negate(enc("on a boat")) @ enc("is fishing"),
    "a person playing the violin AND NOT standing in the street":
        lambda: enc("a person playing the violin") @ negate(enc("standing in the street")),
    "a man and a surfboard AND NOT is surfing":
        lambda: enc("a man and a surfboard") @ negate(enc("is surfing")),
    "a man and a surfboard AND is surfing":
        lambda: enc("a man and a surfboard") @ enc("is surfing"),
}

# Pick which composed query to retrieve with.
QUERY = "a man and a surfboard AND is surfing"
P_query = EXAMPLE_QUERIES[QUERY]()

P_captions = P_captions.to(device)

# Normalized inclusion score tr(P_query @ C) / tr(C) for every caption projector C.
# tr(Q @ C) = sum_ij Q[i, j] * C[j, i] = <vec(C), vec(Q^T)>, so the numerator is
# one mat-vec instead of materializing a (num_captions, D, D) product tensor.
# reshape covers both the batched (1, D, D) AND/NOT results and the `join` result.
Q = P_query.reshape(model.D, model.D)
numerators = P_captions.reshape(len(P_captions), -1) @ Q.transpose(0, 1).reshape(-1)
scores = numerators / torch.einsum("bii->b", P_captions)
idx = torch.argsort(scores, descending=True)

# Show (at most) the ten best-matching captions.
for j in idx[:10].tolist():
    print(f"Normalized inclusion score: {scores[j].item():.3f} | Caption: {captions[j]} (idx {j})")

In [None]:
N_DISPLAY = 6  # number of distinct images to show
SAVE_PATH = None  # set to e.g. "./results.pdf" to also write the figure to disk

# Walk the caption ranking and keep the first N_DISPLAY distinct images.
# Several captions of one image can rank highly, so repeats are skipped.
# The ranking is converted to Python ints once, rather than indexing the
# tensor element by element, and the walk stops if the ranking runs out.
display_indices = []
seen_images = set()
for caption_idx in idx.tolist():
    image_idx = img_indices[caption_idx]
    if image_idx in seen_images:
        continue
    seen_images.add(image_idx)
    display_indices.append(image_idx)
    if len(display_indices) == N_DISPLAY:
        break

fig = plt.figure(figsize=(15, 2))
gs = gridspec.GridSpec(1, N_DISPLAY, wspace=0.05)

for col, image_idx in enumerate(display_indices):
    ax = fig.add_subplot(gs[0, col])
    row = dataset[image_idx]  # fetch the row once for both image and captions
    ax.imshow(row["image"])
    ax.axis("off")
    ax.set_aspect("auto")
    print(row["caption"])

if SAVE_PATH is not None:
    plt.savefig(SAVE_PATH, bbox_inches="tight", pad_inches=0)