### Load in packages and model

In [1]:
%cd ../../..

/home/nils/NILS/Master/DL2/DL2-ZeroVis


In [2]:
from src.fromage_inf.inf_utils import PromptParser
import pickle
import itertools
import torch
from collections import Counter
from PIL import Image
import torchvision.transforms as transforms
import nltk.translate.bleu_score as BLEU
import torch.nn.functional as F
from transformers import CLIPTextModel, AutoTokenizer

In [3]:
parser = PromptParser("src/fromage_inf/fromage_model/")

# Load the relations dictionary to make the relations.
relations = pickle.load(open("src/code/relations_dict.pkl", "rb"))

Using facebook/opt-6.7b for the language model.
Freezing the LM.
Initializing embedding for the retrieval token [RET] (id = 50266).


# Arithmetic Greedy

In [4]:
def recall(generated, ground_truth):
    # split the sentence into words
    words = generated.split()
    
    # count occurrences of each word in the sentence and in the gold label
    words_counter = Counter(words)
    truth_counter = Counter(ground_truth)

    true_positives = 0

    # for each unique word in the sentence, get the minimum count in the sentence and the gold label
    for word in words_counter:
        if word in truth_counter:
            true_positives += min(words_counter[word], truth_counter[word])
    
    recall = true_positives / sum(truth_counter.values())

    return recall

In [5]:
# Initialize the CLIP score metric with the same clip model as used in FROMAGe
clip_text = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14')

# Set smoothing function for bleu
chencherry = BLEU.SmoothingFunction()

total_scores = {}

for relation, values in relations.items():
    print(relation)
    print('=' * 120)

    amount_combi = 0
    recall5 = 0.
    bleu1 = 0.
    clip_s = 0.

    combinations = itertools.combinations(values, 2)

    for combo in combinations:
        tuple1, tuple2 = combo
         
        print('='*60)
        print("Arithmetic:")
        print("{} + ({} - {})".format(tuple1[0], tuple2[1], tuple2[0]))
        print("Expected result: {}".format(tuple1[1]))

        inp_image = parser.model.visual_embs[tuple1[0]] + (parser.model.visual_embs[tuple2[1]] - parser.model.visual_embs[tuple2[0]])

        # Add empty string to prevent error
        prompt = [inp_image,""]

        print('=' * 30)
        # num_words is set to 5 as zerocap uses beam 5 for its experiments
        model_outputs = parser.model.generate_for_images_and_texts(prompt, ret_scale_factor=0, num_words=5)

        print('Model generated outputs:')
        parser.display(model_outputs)

        # SCORES
        amount_combi += 1
        ground_truth = tuple1[1].split("/")[1].replace("_"," ").split()

        # Recall @ 5
        recall5 += recall(model_outputs[0], ground_truth)

        # BLEU-1
        bleu1 += BLEU.sentence_bleu([ground_truth], model_outputs[0].split(),  weights=(1.,0.), smoothing_function=chencherry.method1)

        # CLIP-s
        x1 = tokenizer(['Image of a {}'.format(ground_truth)], padding=False, return_tensors='pt')
        x2 = tokenizer(model_outputs[0], padding=False, return_tensors='pt')
        x1_tensor = clip_text(**x1).last_hidden_state.squeeze()
        x2_tensor = clip_text(**x2).last_hidden_state.squeeze()

        cos = F.normalize(x1_tensor) @ F.normalize(x2_tensor).T
        clip_s += torch.mean(cos.squeeze()).item()
        
        print('='*60)
        print("Arithmetic:")
        print("{} + ({} - {})".format(tuple2[0], tuple1[1], tuple1[0]))
        print("Expected result: {}".format(tuple2[1]))

        inp_image = parser.model.visual_embs[tuple2[0]] + (parser.model.visual_embs[tuple1[1]] - parser.model.visual_embs[tuple1[0]])

        # Add empty string to prevent error
        prompt = [inp_image,""]

        print('=' * 30)
        model_outputs = parser.model.generate_for_images_and_texts(prompt, ret_scale_factor=0, num_words=5)

        print('Model generated outputs:')
        parser.display(model_outputs)

        # SCORES
        amount_combi += 1
        ground_truth = tuple2[1].split("/")[1].replace("_"," ").split()

        # Recall @ 5
        recall5 += recall(model_outputs[0], ground_truth)

        # BLEU-1
        bleu1 += BLEU.sentence_bleu([ground_truth], model_outputs[0].split(),  weights=(1.,0.), smoothing_function=chencherry.method1)

        # CLIP-s
        x1 = tokenizer(['Image of a {}'.format(ground_truth)], padding=False, return_tensors='pt')
        x2 = tokenizer(model_outputs[0], padding=False, return_tensors='pt')
        x1_tensor = clip_text(**x1).last_hidden_state.squeeze()
        x2_tensor = clip_text(**x2).last_hidden_state.squeeze()

        cos = F.normalize(x1_tensor) @ F.normalize(x2_tensor).T
        clip_s += torch.mean(cos.squeeze()).item()

    # relation_outs[relation] = rel_list
    relation_scores = {"CLIP-s":clip_s/amount_combi,
                       "Recall@5": recall5/amount_combi,
                       "BLEU-1": bleu1/amount_combi}
    total_scores[relation] = relation_scores

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.22.mlp.fc2.weight', 'vision_model.encoder.layers.15.layer_norm1.weight', 'vision_model.post_layernorm.weight', 'vision_model.encoder.layers.12.self_attn.k_proj.weight', 'vision_model.encoder.layers.13.layer_norm1.weight', 'vision_model.encoder.layers.19.self_attn.q_proj.weight', 'vision_model.encoder.layers.16.self_attn.v_proj.bias', 'vision_model.encoder.layers.3.mlp.fc1.weight', 'vision_model.encoder.layers.0.mlp.fc1.bias', 'vision_model.encoder.layers.21.layer_norm2.weight', 'vision_model.encoder.layers.6.self_attn.v_proj.weight', 'vision_model.encoder.layers.16.mlp.fc2.bias', 'vision_model.encoder.layers.12.mlp.fc1.weight', 'vision_model.encoder.layers.1.layer_norm2.weight', 'vision_model.encoder.layers.23.mlp.fc2.bias', 'vision_model.encoder.layers.7.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.mlp.fc1.weight', 'vision_mode

CEOs -> companies
Arithmetic:
CEOs/mark_zuckerberg + (companies/microsoft - CEOs/bill_gates)
Expected result: companies/facebook
Model generated outputs:
the logo is a vector
Arithmetic:
CEOs/bill_gates + (companies/facebook - CEOs/mark_zuckerberg)
Expected result: companies/microsoft
Model generated outputs:
the logo of the company
Arithmetic:
CEOs/mark_zuckerberg + (companies/tesla - CEOs/elon_musk)
Expected result: companies/facebook
Model generated outputs:
logo on a white
Arithmetic:
CEOs/elon_musk + (companies/facebook - CEOs/mark_zuckerberg)
Expected result: companies/tesla
Model generated outputs:
the logo of the company
Arithmetic:
CEOs/mark_zuckerberg + (companies/amazon - CEOs/jeff_bezos)
Expected result: companies/facebook
Model generated outputs:
the logo on the door
Arithmetic:
CEOs/jeff_bezos + (companies/facebook - CEOs/mark_zuckerberg)
Expected result: companies/amazon
Model generated outputs:
the logo of the company
Arithmetic:
CEOs/mark_zuckerberg + (companies/apple 

In [7]:
for relation, values in total_scores.items():
    print('='*30)
    print("Relationship: ")
    print(relation)
    print("Scores: \n")

    for key, val in values.items():
        print(key, ": ", val)

Relationship: 
CEOs -> companies
Scores: 

CLIP-s :  0.16118695959448814
Recall@5 :  0.0
BLEU-1 :  0.0
Relationship: 
flags -> capital
Scores: 

CLIP-s :  0.15218415749118183
Recall@5 :  0.0
BLEU-1 :  0.0
Relationship: 
food -> countries
Scores: 

CLIP-s :  0.15927558888991675
Recall@5 :  0.0
BLEU-1 :  0.0
Relationship: 
building -> countries
Scores: 

CLIP-s :  0.15548672221068824
Recall@5 :  0.017857142857142856
BLEU-1 :  0.004464285714285714
Relationship: 
flags -> leaders
Scores: 

CLIP-s :  0.1405576091673639
Recall@5 :  0.005555555555555556
BLEU-1 :  0.0022222222222222222
