In [1]:
import os
os.environ['CRYPTOGRAPHY_OPENSSL_NO_LEGACY'] = '1'

from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import copy
import torch
torch.cuda.set_device(0)

from transformers import logging
logging.set_verbosity_error()

from PIL import Image
# !pip install matplotlib
import matplotlib.pyplot as plt

from fromage import models
from fromage import utils
from PIL import Image, ImageDraw, ImageFont, ImageOps

In [2]:
def trunc_caption(caption: str) -> str:
    # Truncate at period.
    trunc_index = caption.find('.') + 1
    if trunc_index < 0:
        trunc_index = caption.find('\n') + 1
    caption = caption[:trunc_index]
    return caption

def display_interleaved_outputs(model_outputs, one_img_per_ret=True):
    for output in model_outputs:
        if type(output) == str:
            print(output)
        elif type(output) == list:
            if one_img_per_ret:
                plt.figure(figsize=(3, 3))
                plt.imshow(np.array(output[0]))
            else:
                fig, ax = plt.subplots(1, len(output), figsize=(3 * len(output), 3))
                for i, image in enumerate(output):
                    image = np.array(image)
                    ax[i].imshow(image)
                    ax[i].set_title(f'Retrieval #{i+1}')
            plt.show()
        elif type(output) == Image.Image:
            plt.figure(figsize=(3, 3))
            plt.imshow(np.array(output))
            plt.show()


def get_image_from_path(image_path: str):
    img = Image.open(image_path)
    img = img.resize((224, 224))
    img = img.convert('RGB')
    return img


In [19]:
import argparse
import json
import os
import sys
import time
import warnings
import glob
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import pickle as pkl
from collections import namedtuple
from transformers import AutoTokenizer
from transformers import OPTForCausalLM, GPT2Tokenizer
from transformers import LlamaForCausalLM, LlamaTokenizer
from fromage.models import MCL

def load_model(model_dir: str,Llama=False):
  model_args_path = os.path.join(model_dir, 'model_args.json')
  model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt_10.pth.tar')



  if not os.path.exists(model_args_path):
    raise ValueError(f'model_args.json does not exist in {model_dir}.')
  if not os.path.exists(model_ckpt_path):
    raise ValueError(f'pretrained_ckpt.pth.tar does not exist in {model_dir}.')


  with open(model_args_path, 'r') as f:
      model_kwargs = json.load(f)

  LenRET = 1
  if 'LenRET' in model_kwargs.keys():
    LenRET = model_kwargs['LenRET']
  # Initialize tokenizer.
  if Llama:
    tokenizer = LlamaTokenizer.from_pretrained("fromage/llama/hugging-llama-2-7b")
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
  else:
    tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'],local_files_only=False)
    tokenizer.pad_token = tokenizer.eos_token

  tokenizer.add_special_tokens({"cls_token": "<|image|>"})
  if LenRET==1:
    tokenizer.add_tokens('[RET]')
    ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
    assert len(ret_token_idx) == 1, ret_token_idx
    model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
  else:
    model_kwargs['retrieval_token_idx'] = []
    for i in range(LenRET):
      RET_Token = f'[RET{i}]'
      tokenizer.add_tokens(RET_Token)
      ret_token_idx = tokenizer(f'[RET{i}]', add_special_tokens=False).input_ids
      assert len(ret_token_idx) == 1, ret_token_idx
      model_kwargs['retrieval_token_idx'].append(ret_token_idx[0])


  model_kwargs['Llama']=Llama
  args = namedtuple('args', model_kwargs)(**model_kwargs)

  model = MCL(tokenizer, args)
  model = model.eval()
  model = model.bfloat16()
  model = model.cuda()

  # Load pretrained linear mappings and [RET] embeddings.
  checkpoint = torch.load(model_ckpt_path)
  model.load_state_dict(checkpoint['state_dict'], strict=False)
  with torch.no_grad():
    if LenRET==1:
      model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].squeeze().cpu().detach())
    else:
      model.model.input_embeddings.weight[model.model.retrieval_token_idx[0]:model.model.retrieval_token_idx[-1]+1, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].squeeze().cpu().detach())
  logit_scale = model.model.logit_scale.exp()


  return model

In [20]:
# Load model used in the paper.
model_dir = './runs/icml_save/'

model = load_model(model_dir,True)

Using HuggingFace AutoFeatureExtractor for openai/clip-vit-large-patch14.
Using openai/clip-vit-large-patch14 for the visual model with 4 visual tokens.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using Llama for the language model.
Freezing the LM.
Initializing embedding for the retrieval token [RET] (id = [32002, 32003, 32004, 32005, 32006]).
Restoring pretrained weights for the visual model.
Freezing the VM.


In [11]:
# from transformers import CLIPProcessor, CLIPModel,CLIPTextModel,CLIPVisionModel
# clip_tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32",local_files_only=True)

In [6]:
from typing import Callable, List, Optional, Tuple, Union
def generate_for_cir(
    model, prompts: List, num_words: int = 0, ret_scale_factor: float = 1.0, top_p: float = 1.0, temperature: float = 0.0,
    max_num_rets: int = 1, attention_mask =None,if_mask=True):

    input_embs = []
    input_ids = []
    all_visual_embs = []
    add_bos = True

    attention_mask = []
    for i, p in enumerate(prompts):
      if type(p) == Image.Image:
        # Encode as image.
        pixel_values = utils.get_pixel_values_for_model(model.model.feature_extractor, p)
        pixel_values = pixel_values.to(device=model.model.logit_scale.device, dtype=model.model.logit_scale.dtype)
        pixel_values = pixel_values[None, ...]

        visual_embs = model.model.get_visual_embs(pixel_values, mode='cap')  # (1, n_visual_tokens, D)

        input_embs.append(visual_embs)
        all_visual_embs.append(visual_embs)
        attention_mask.append(torch.ones(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device))
      elif type(p) == str:
        text_ids = model.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(model.model.logit_scale.device)

        masks_tc = model.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").attention_mask.to(model.model.logit_scale.device)
        if isinstance(model.model.retrieval_token_idx,list):
            for i in range(text_ids.shape[-1]):
              if text_ids[0][i] in model.model.retrieval_token_idx:
                masks_tc[0][i]+=1
        attention_mask.append(masks_tc)
        if not add_bos:
          # Remove <bos> tag.
          text_ids = text_ids[:, 1:]
          masks_tc = masks_tc[:, 1:]
        else:
          # Only add <bos> once.
          add_bos = False

        text_embs = model.model.input_embeddings(text_ids)  # (1, T, D)

        input_embs.append(text_embs)
        input_ids.append(text_ids)
      else:
        raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
    input_embs = torch.cat(input_embs, dim=1)
    input_ids = torch.cat(input_ids, dim=1)
    attention_mask = torch.cat(attention_mask,dim=1)
    
    if num_words == 0:
      generated_ids = input_ids
      if if_mask ==False:
        attention_mask = None
      outputs = model.model.lm(inputs_embeds=input_embs, use_cache=False, attention_mask=attention_mask,output_hidden_states=True)
      # Map outputs to embeddings, so we can retrieve embeddings from the [RET] tokens.
      out = []
      for x, fc in zip(model.model.args.text_emb_layers, model.model.text_hidden_fcs):
          out.append(fc(outputs.hidden_states[x]))
      embeddings = torch.stack(out, dim=-1).sum(dim=-1)

    elif num_words > 0:
      generated_ids, generated_embeddings, _ = model.model.generate(input_embs, num_words,
        temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
      embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]

      # Truncate to newline.
      newline_token_id = model.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
      trunc_idx = 0
      for j in range(generated_ids.shape[1]):
        if generated_ids[0, j] == newline_token_id:
          trunc_idx = j
          break
      if trunc_idx > 0:
        generated_ids = generated_ids[:, :trunc_idx]
        embeddings = embeddings[:, :trunc_idx]
    else:
      raise ValueError

    return_outputs = []
    # Find up to max_num_rets [RET] tokens, and their corresponding scores.
    if isinstance(model.model.retrieval_token_idx,list):
      all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == model.model.retrieval_token_idx[0]) if x][:max_num_rets]
    else:
      all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == model.model.retrieval_token_idx) if x][:max_num_rets]
    seen_image_idx = []  # Avoid showing the same image multiple times.

    last_ret_idx = 0
    ret_emb_ori=None
    if len(all_ret_idx) == 0:

      caption = model.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
      return_outputs.append(utils.truncate_caption(caption))
    else:
      ret_emb_ori = embeddings[:, all_ret_idx[0]+4:, :]

      ret_emb_ori = model.model.transformer_fusion(ret_emb_ori)

      ret_emb = torch.mean(ret_emb_ori, dim=1) # (N,D)


    return ret_emb


In [8]:
circo_image_path = './data/unlabeled2017/'
cirr_image_path = './data/'
data_dir = './data/'

embs_path = './features/clip_model_CIRCO_embeddings.pkl'

with open(embs_path, 'rb') as wf:
    train_embs_data = pkl.load(wf)
    path_array=(train_embs_data['paths'])
    emb_matrix=(train_embs_data['embeddings'])

emb_matrix = torch.tensor(emb_matrix).cuda().float()
emb_matrix = emb_matrix / emb_matrix.norm(dim=-1, keepdim=True)
path_array = [path.split('/')[-1] for path in path_array]

In [21]:


all_visual_embs = []
all_caption=[]
all_outputs =[]
ret_scale_factor=0

image_path = circo_image_path
from tqdm import tqdm
import json
import glob

with open('repos/CIRCO/annotations/val.json','r') as f:
    CIRCO_val = json.load(f)

ifdisplay = False

Topk = {}
for mapk in [5,10,25,50]:
    Topk[mapk]=[]
with torch.no_grad():
    for sample in tqdm(CIRCO_val[:]):
        reference_img_id, target_img_id, relative_caption, _, gt_img_ids, id_,_ = sample.values()
        refer_image_path = image_path + str(reference_img_id).zfill(12) + '.jpg'
        refer_idx = path_array.index(str(reference_img_id).zfill(12)+ '.jpg')

        gt_indexs = []
        gt_image_paths = []
        for id_ in gt_img_ids:
            gt_image_path = image_path + str(id_).zfill(12) + '.jpg'
            gt_image_paths.append(gt_image_path)
            index = path_array.index(str(id_).zfill(12)+ '.jpg')
            gt_indexs.append(index)
        
        inp_image = get_image_from_path(refer_image_path)

        inp_text = 'Q:'+ relative_caption+ ".\nA:it becomes a photo of [RET0][RET1][RET2][RET3][RET4]"

        prompt = [inp_image, inp_text]
        if ifdisplay:
            display_interleaved_outputs(prompt)

        ret_emb = generate_for_cir(model,prompt,if_mask=True)

        
        ret_emb/=ret_emb.norm(dim=-1,keepdim=True)
        
        model_outputs =  emb_matrix @ (ret_emb.float()).T
        model_outputs[refer_idx]=-100

        
        for mapk in [5,10,25,50]:
            value,indexs = model_outputs.squeeze().topk(mapk)
            k=0
            x=1
            for index in indexs:
                if index in gt_indexs:
                    k+=float(1/x)
                x+=1
            Topk[mapk].append(k/min(len(gt_indexs),mapk))

    for mapk in [5,10,25,50]:
        print('makp ',mapk,torch.tensor(Topk[mapk]).mean())


100%|██████████| 220/220 [01:11<00:00,  3.09it/s]

makp  5 tensor(0.1630)
makp  10 tensor(0.1662)
makp  25 tensor(0.1731)
makp  50 tensor(0.1760)





In [22]:
#102

# image_path = '/home/liwei/exp/data/unlabeled2017/'
image_path = './data/unlabeled2017/'
from tqdm import tqdm
import json
import glob
with open('./repos/CIRCO/annotations/test.json','r') as f:
    CIRCO_test = json.load(f)
    
ifdisplay = False

Topk = {}
for mapk in [5,10,25,50]:
    Topk[mapk]=[]
results_circo = {}


all_caption = []
with torch.no_grad():

    for sample in tqdm(CIRCO_test[:]):
        reference_img_id, relative_caption, shared_concept, id_ = sample.values()
        refer_image_path = image_path + str(reference_img_id).zfill(12) + '.jpg'
        refer_idx = path_array.index(str(reference_img_id).zfill(12) + '.jpg')

        inp_image = get_image_from_path(refer_image_path)

        inp_text = 'Q:'+ relative_caption + '.\nA:it becomes a photo of[RET0][RET1][RET2][RET3][RET4]'

        prompt = [inp_image, inp_text]
        if ifdisplay:
            display_interleaved_outputs(prompt)
        ret_emb = generate_for_cir(model,prompt,if_mask=True)
        ret_emb/=ret_emb.norm(dim=-1,keepdim=True)


        model_outputs =  emb_matrix @ (ret_emb.float()).T
        model_outputs[refer_idx]=-100


        similarity_all = model_outputs
        value,indexs = similarity_all.squeeze().topk(50)
        results_circo[str(id_)]=[]
        for index in indexs:
            index_circo = int(path_array[index].split('/')[-1].split('.')[0])
            results_circo[str(id_)].append(index_circo)

100%|██████████| 800/800 [04:01<00:00,  3.32it/s]


In [23]:
with open('results/circo_results.json','w') as f:
    json.dump(results_circo,f)

In [25]:
import pickle as pkl
from tqdm import tqdm
import json
with open('repos/CIRR/captions/cap.rc2.val.json','r') as f:
    cirr_eval = json.load(f)
with open('features/clip_model_NLVR2_embeddings.pkl', 'rb') as wf:
    cirr_embs_data = pkl.load(wf)
with open('repos/CIRR/image_splits/split.rc2.val.json','r') as f:
    test1=json.load(f)
embedding_cirr =cirr_embs_data['embeddings']
path_array_cirr = list(cirr_embs_data['paths'])
path_array_cirr = [path.split('/')[-1] for path in path_array_cirr]
image_path = 'data/NLVR2/dev/'
Topk = {}
ifdisplay=False
R1 = 0
R1_all=0
num_sample = 0
embedding_cirr = torch.tensor(embedding_cirr).cuda()
embedding_cirr /= embedding_cirr.norm(dim=-1,keepdim=True)

all_gallery_indexs = [path_array_cirr.index(str(gallery_img_id) + '.png') for gallery_img_id in test1.keys()]
all_gallery_embeddings = embedding_cirr[all_gallery_indexs]
fusion=False
all_caption=[]
with torch.no_grad():

    for sample in tqdm(cirr_eval[:]):
        num_sample+=1
        reference_img_id = sample['reference']
        target_img_id = sample['target_hard']
        relative_caption = sample['caption']
        gallery_img_ids = sample['img_set']['members'].copy()
        gallery_img_ids.remove(reference_img_id)
        gt_rank = gallery_img_ids.index(sample['target_hard'])

        refer_image_path = image_path + str(reference_img_id) + '.png'
        refer_idx = path_array_cirr.index(str(reference_img_id) + '.png')


        gt_image_path = image_path + str(target_img_id)+ '.png'
        gt_indexs = path_array_cirr.index(str(target_img_id)+ '.png')
        gt_rank_all = all_gallery_indexs.index(gt_indexs)
        refer_index_all = all_gallery_indexs.index(refer_idx)
        gallery_indexs = [path_array_cirr.index(str(gallery_img_id) + '.png') for gallery_img_id in gallery_img_ids]

        inp_image = get_image_from_path(refer_image_path)
        
            
        inp_text = "Q:"+ relative_caption.lower()+".\nA:it becomes a photo of [RET0][RET1][RET2][RET3][RET4]"
        prompt = [inp_image,inp_text]
        if ifdisplay:
            display_interleaved_outputs(prompt)


        emb= generate_for_cir(model,prompt,if_mask=True)

        emb/=emb.norm(dim=-1,keepdim=True)

        similarity = embedding_cirr[gallery_indexs]@emb.float().T
        _,top_index = similarity.squeeze().topk(3)
        if similarity.argmax()==gt_rank:
            R1+=1

        similarity_all = all_gallery_embeddings@emb.float().T
        _,top_index = similarity_all.squeeze().topk(3)
        similarity_all[refer_index_all]=-100
        if similarity_all.argmax()==gt_rank_all:
            R1_all+=1

    

print(R1/num_sample)
print(R1_all/num_sample)

100%|██████████| 4181/4181 [21:13<00:00,  3.28it/s]

0.6316670652953839
0.2535278641473332





In [132]:
import pickle as pkl
with open('repos/CIRR/captions/cap.rc2.test1.json','r') as f:
    cirr_test1 = json.load(f)
with open('features/clip_model_NLVR2_test1_embeddings.pkl', 'rb') as wf:
    cirr_embs_data = pkl.load(wf)
with open('repos/CIRR/image_splits/split.rc2.test1.json','r') as f:
    test1=json.load(f)
    
embedding_cirr = cirr_embs_data['embeddings']
path_array_cirr = cirr_embs_data['paths']


image_path = 'data/NLVR2/test1/'
results_subset = {"version":"rc2","metric":"recall_subset"}

results_all ={"version":"rc2","metric":"recall"}



Topk = {}
ifdisplay=False
R1 = 0
num_sample = 0
embedding_cirr = torch.tensor(embedding_cirr).cuda()
embedding_cirr /= embedding_cirr.norm(dim=-1,keepdim=True)

all_gallery_indexs = [path_array_cirr.index(str(gallery_img_id) + '.png') for gallery_img_id in test1.keys()]
fusion=False
all_caption=[]
with torch.no_grad():

    all_gallery_embeddings = embedding_cirr[all_gallery_indexs]
    for sample in tqdm(cirr_test1[:]):
        num_sample+=1
        pairid = str(sample['pairid'])
        reference_img_id = sample['reference']
        relative_caption = sample['caption']
        gallery_img_ids = sample['img_set']['members'].copy()
        gallery_img_ids.remove(reference_img_id)

        refer_image_path = image_path + str(reference_img_id) + '.png'
        refer_idx = path_array_cirr.index(str(reference_img_id) + '.png')
 
        refer_index_all = all_gallery_indexs.index(refer_idx)

        gallery_indexs = [path_array_cirr.index(str(gallery_img_id) + '.png') for gallery_img_id in gallery_img_ids]

        inp_image = get_image_from_path(refer_image_path)

        inp_text = 'Q:'+ relative_caption.lower()+'.\nA:it becomes a photo of [RET0][RET1][RET2][RET3][RET4]'
        prompt = [inp_image, inp_text]
        if ifdisplay:
            display_interleaved_outputs(prompt)

        emb= generate_for_images_and_texts_circo(model,prompt,if_mask=True)

        emb/=emb.norm(dim=-1,keepdim=True)

        similarity = embedding_cirr[gallery_indexs]@emb.float().T
        value,indexs = similarity.squeeze().topk(3)
        results_subset[pairid]=[]
        for index in indexs:
            results_subset[pairid].append(gallery_img_ids[index])
        similarity_all = all_gallery_embeddings@emb.float().T
        similarity_all[refer_index_all]=-100
        value,indexs = similarity_all.squeeze().topk(50)
        results_all[pairid]=[]
        for index in indexs:
            results_all[pairid].append(list(test1.keys())[index])


print(R1/num_sample)

with open('results/results_subset.json','w') as f:
    json.dump(results_subset,f)
with open('results/results_all.json','w') as f:
    json.dump(results_all,f)


100%|██████████| 4148/4148 [21:11<00:00,  3.26it/s]

0.0





FileNotFoundError: [Errno 2] No such file or directory: 'cirr_results/results_subset.json'