In [1]:
####
#### To run this notebook you also need to install Hugging Face's transformers library:
#### https://huggingface.co/docs/transformers/index
####

In [2]:
import os
import torch
import pandas as pd
import os.path as osp
from PIL import Image
from tqdm.notebook import tqdm as tqdm
from transformers import CLIPModel, CLIPProcessor
from changeit3d.in_out.language_contrastive_dataset import LanguageContrastiveDataset

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
# Input locations -- adjust both to your local setup.
# Directory holding the .png images that are looked up by file name.
top_img_dir = '../../data/shapetalk/images/full_size'
# CSV file with the utterances (read with pandas further below).
st_data = '../../data/shapetalk/language/shapetalk_preprocessed_public_version_0.csv'

In [4]:
# Evaluation settings.
model_name = "openai/clip-vit-base-patch32"  # identifier handed to from_pretrained()
batch_size = 256  # examples per DataLoader batch
num_workers = 12  # DataLoader worker processes
# Prefer the first GPU, but fall back to the CPU when CUDA is not available so
# the notebook still runs (more slowly) instead of failing at `.to(device)`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
# Tag derived from the model identifier with '/' replaced by '_'
# (used below when naming the output file).
method_name = model_name.replace('/', '_')

# Load the matching processor and the pretrained CLIP model; the model is
# moved to the selected device and put in evaluation mode.
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
model = model.to(device)
model = model.eval()

In [6]:
# Read the utterance table, keep only the rows of the listening 'test' split,
# renumber the index from zero, and expose the target / source uids under the
# "target" / "distractor_1" column names.
df = (
    pd.read_csv(st_data)
    .loc[lambda frame: frame.listening_split == 'test']
    .reset_index(drop=True)
    .assign(
        target=lambda frame: frame.target_uid.copy(),
        distractor_1=lambda frame: frame.source_uid.copy(),
    )
)
print(len(df))

53341


In [7]:
# Tokenize the ShapeTalk utterances with CLIP's tokenizer.
use_clean_utters = True

if use_clean_utters:
    # Spell-checked utterances; strip our special '-er' / '-est' tokens that mark
    # comparative / superlative adjective endings (plain substring removal).
    utterance_to_use = (
        df.utterance_spelled
        .str.replace('-er', '', regex=False)
        .str.replace('-est', '', regex=False)
    )
else:
    utterance_to_use = df.utterance  # the original text, without spell-checking etc.

# Pad all utterances to the longest one so every row has the same number of ids
# (quick & compatible with LanguageContrastiveDataset).
# Item assignment is deliberate: `df.tokens_encoded = ...` only updates a column
# that already exists and otherwise just sets a Python attribute on the frame,
# silently leaving the table without the new tokens.
df['tokens_encoded'] = processor(text=utterance_to_use.tolist(), padding="longest")['input_ids']

# These columns are not used here (dropped to avoid confusion). errors='ignore'
# keeps the cell re-runnable once they have already been removed.
df = df.drop(columns=['tokens', 'tokens_len'], errors='ignore')

In [8]:
# package in dataset
def to_stimulus_func(file_name):
    """Load one image and convert it to the model's pixel-value input.

    `file_name` is the image location relative to `top_img_dir`, without the
    '.png' extension. The image is opened, converted to RGB and passed through
    the CLIP processor; the first (and only) entry of the resulting
    'pixel_values' is returned.
    """
    image_path = osp.join(top_img_dir, file_name + '.png')
    rgb_image = Image.open(image_path).convert('RGB')
    processed = processor(images=rgb_image, return_tensors="pt")
    return processed['pixel_values'][0]

# Wrap the table in the contrastive dataset (one distractor per target) and
# batch it. No shuffling is requested; the per-row results are later attached
# to `df` by position -- assumes the dataset serves rows in `df` order (TODO confirm).
dataset = LanguageContrastiveDataset(
    df,
    n_distractors=1,
    to_stimulus_func=to_stimulus_func,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [9]:
# Score each utterance against its distractor image and its target image.
all_logits = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        assert all(batch['label'] == 1) # target is last
        stimuli = batch['stimulus']
        token_ids = batch['tokens'].to(device)
        n_batch = len(token_ids)

        # Images are concatenated distractors first, targets second, so rows
        # [0, n_batch) are distractors and rows [n_batch, 2 * n_batch) are targets.
        images = torch.cat([stimuli[:, 0], stimuli[:, 1]]).to(device)
        res = model(input_ids=token_ids, pixel_values=images)

        # The diagonal of each half pairs the i-th image with the i-th utterance.
        per_image = res.logits_per_image
        distractor_logits = per_image[:n_batch].diagonal()
        target_logits = per_image[n_batch:].diagonal()
        all_logits.append(torch.stack([distractor_logits, target_logits], dim=1).cpu())

        # optional: the text-side logits must give the same pairings
        per_text = res.logits_per_text
        assert torch.allclose(distractor_logits, per_text[:, :n_batch].diagonal())
        assert torch.allclose(target_logits, per_text[:, n_batch:].diagonal())

all_logits = torch.cat(all_logits)
# Column 1 holds the target score; a guess is correct when it wins the argmax.
probabilities = all_logits.softmax(1)
guessed_correctly = probabilities.argmax(1).eq(1).double()
print('Accuracy', guessed_correctly.mean())

  0%|          | 0/209 [00:00<?, ?it/s]

Accuracy tensor(0.5305, dtype=torch.float64)


In [10]:
# Save the test rows together with a per-row flag telling whether the
# (not fine-tuned) CLIP model picked the target.
out_file = f'../../data/pretrained/listeners/all_shapetalk_classes/rs_2022/single_utter/{method_name}_not_finetuned_on_test_data.csv'
# to_csv does not create missing folders; make sure the nested output directory exists.
os.makedirs(osp.dirname(out_file), exist_ok=True)
df['guessed_correctly'] = guessed_correctly.to(bool).tolist()
df.to_csv(out_file, index=False)