This script is used to create "negative" captions for images from the Flickr30K test dataset using MLM. Steps are roughly as follows:
### 1. Identification of objects to mask and replace in captions
 * We use nltk parsing to extract noun phrases or objects that are candidates for substitution in the captions for the different images

### 2. Apply MLM to get candidate substitution phrases for masked words
 * We apply BertForMaskedLM on masked sentences

### 3. Analysis of candidate "negative" captions
  * Not all replacement words/ phrases are appropriate. We apply some post-processing to select reasonable candidates from the MLM

## Inital Set-up

In [1]:
!pip install transformers
!pip install nltk

Collecting transformers
  Downloading transformers-4.19.0-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 5.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.6.0-py3-none-any.whl (84 kB)
[K     |████████████████████████████████| 84 kB 1.7 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 27.9 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 49.4 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
Successfully installed huggingface-hub-0.

In [2]:
import os, sys
import pandas as pd
import numpy as np
import json
import spacy
import en_core_web_sm
import nltk
from nltk import word_tokenize
from nltk import pos_tag
from transformers import BertTokenizer, BertForMaskedLM
from torch.nn import functional as F
import torch
import re
from nltk.corpus import wordnet
from tqdm.auto import tqdm
import pickle

import logging

logging.basicConfig(level=logging.DEBUG)

pipeline_device = 0 if torch.cuda.is_available() else -1
print(f"Using device: {pipeline_device}")
model_type = "roberta-large"
k=10
output_base = f"output/masked_inst-{model_type}-{k}"

Using device: -1


In [3]:
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('wordnet')

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


True

In [4]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [5]:
%cd "/content/gdrive/MyDrive/NYU/DL Systems/final_project/"
!ls

/content/gdrive/.shortcut-targets-by-id/1mkkPzbGFRuZ52OVppGv8cKxDIdSjwPYf/final_project
 data					  flickr_eval.py
'DL Systems Final Presentation.gslides'   flickr_eval.sh
 eval_flickr.ipynb			  mdetr
'Experiment Tracker.gsheet'		  MLM_object_replacement.ipynb
'Final Presentation.gslides'		  output


In [6]:
# data_dir = '/content/gdrive/MyDrive/DL_systems/final_project/data/'
# flickr_anns = json.load(open(os.path.join(data_dir, 'final_flickr_mergedGT_test.json'),'r'))
with open("data/final_flickr_mergedGT_test.json") as f:
  flickr_anns = json.load(f)

len(flickr_anns['images']), len(flickr_anns['annotations'])

(4969, 14481)

In [7]:
flickr_anns.keys()

dict_keys(['info', 'licenses', 'images', 'annotations', 'categories'])

Example of image data in annotation file:

In [8]:
flickr_anns['images'][0]

{'caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the line .',
 'dataset_name': 'flickr',
 'file_name': '1016887272.jpg',
 'height': '500',
 'id': 153901,
 'original_img_id': 1016887272,
 'sentence_id': 0,
 'tokens_negative': [[0, 97]],
 'tokens_positive_eval': [[[0, 16]],
  [[39, 47]],
  [[54, 61]],
  [[65, 68]],
  [[87, 95]]],
 'width': '333'}

In [9]:
flickr_anns['images'][5]

{'caption': 'Two children playing on the beach in the sand with the ocean in the background',
 'dataset_name': 'flickr',
 'file_name': '7162685234.jpg',
 'height': '333',
 'id': 153906,
 'original_img_id': 7162685234,
 'sentence_id': 0,
 'tokens_negative': [[0, 78]],
 'tokens_positive_eval': [[[0, 12]], [[24, 33]], [[37, 45]], [[51, 60]]],
 'width': '500'}

Example of annotation data in annotation file:

In [10]:
flickr_anns['annotations'][0]

{'area': 21186.0,
 'bbox': [74.0, 302.0, 107.0, 198.0],
 'category_id': 1,
 'id': 430844,
 'image_id': 153901,
 'iscrowd': 0,
 'phrase_ids': 3,
 'tokens_positive': [[54, 61]]}

## Identification of objects to mask and replace in captions

In [11]:
article_list = ['the', 'a', 'an']
def parse_NP(sentence):
    '''
    Function for extracting noun phrases from sentences (these will be objects and object + modifiers we will be replacing to create negative captions)
    '''

    #Define grammar for parsing tree
    grammar = """NP: {<DT>?<JJ>*<NN.*>+}
                    RELATION: {<V.*>}
                                {<DT>?<JJ>*<NN.*>+}
                    ENTITY: {<NN.*>}"""
    
    parser = nltk.RegexpParser(grammar)
    NP_list = [' '.join(leaf[0] for leaf in tree.leaves()) for tree in parser.parse(nltk.pos_tag(nltk.word_tokenize(sentence))).subtrees() if tree.label() =='NP']
    
    #remove articles from NP
    NP_list = [' '.join([w for w in str.split(np,' ') if str.lower(w) not in(article_list)]) for np in NP_list]
    return NP_list


In [12]:
'''Update annotations dictionary (image keys specificially to include noun phrases and ids associated w/ them)
This is important because these NPs are what we are going to find candidate replacements for using MLM
Each caption can have multiople NPs so we need to have ids associated with each to map back to the original image/caption once we tokenize all the captions and load them in dataloader
'''
np_id = 0
for i, img in enumerate(flickr_anns['images']):
    img['NPs'] = parse_NP(img['caption'])
    img['NP_ids'] = []
    for i in range(len(img['NPs'])):
        #Add in ID for noun phrases for mapping back when we add to dataset
        img['NP_ids'].append(np_id)
        np_id += 1

In [13]:
people_words = np.array(['man', 'woman', 'men', 'women', 'children', 'child', 'girl', 'boy', 'boys', 'girls', 'father', 'fathers', 'son', 'sons', 'husband', 'husbands', 'mother','mothers', 'parent', 'parents', 'daughter', 'daughters', 'wife', 'wives',
                  'spouse', 'spouses', 'partner', 'partners', 'brother', 'brothers', 'sister', 'sisters', 'sibling', 'siblings', 'grandfather','grandfathers',	'grandmother', 'grandmothers',	'grandparents', 'uncle', 'uncles', 'aunt', 'aunts',
                   'nephew', 'nephews', 'niece', 'nieces', 'cousin', 'cousins', 'person', 'people', 'guy', 'guys', 'lady', 'ladies'])
pronoun_words = np.array(["this", "that", "these", "those", "another", "anybody", "anyone", "anything", "each", "either", "enough", "everybody", "everyone",
                "everything", "little", "much", "neither", "nobody", "no one", "nothing", "one", "other", "somebody",
                "something", "both", "few", "fewer", "many", "others", "several", "all", "any", "more", "most",
                "someone", "none", "some", "such", "who", "whom", "whose", "what", "which", 'it', 'them', 'they', 'her', 'she', 'him', 'he', 'i', 'I'])

We're going through every image and its associated caption.
Each caption has a set of noun-phrases (NPs) which we can use
as candidate words to mask out.

In [None]:
import re

MASK_TOKEN = "<mask>"

def word_count(np, caption):
    # https://stackoverflow.com/questions/17268958/finding-occurrences-of-a-word-in-a-string-in-python-3
    return sum(1 for _ in re.finditer(r'\b%s\b' % re.escape(np), caption))

def has_banned_words(NP):
    # Check if any of the people_words or pronoun_words are this word
    return NP in people_words or NP in pronoun_words

def is_valid(np, caption):
    if len(np.split(' ')) > 1:
        # logging.debug(f"IGNORE: NP '{np}' is multiple words")
        return False

    if word_count(np, caption) > 1:
        # logging.debug(f"IGNORE: Caption '{caption}' contains multiple instances of mask word '{np}'")
        return False

    if has_banned_words(np):
        # logging.debug(f"IGNORE: NP '{np}' contains banned words")
        return False

    return True


def replace_mask_word(word, caption):
    return re.sub(rf'\b{word}\b', MASK_TOKEN, caption)

cnt = 0
for image_idx, img in enumerate(flickr_anns['images']):
    caption = img['caption']
    NPs = img['NPs']
    masked_captions = []
    masked_nps = []

    for NP in NPs:
        if not is_valid(NP, caption):
            continue

        # NP has passed all tests, so lets replace the NP with the mask
        masked_caption = replace_mask_word(NP, caption)

        # Save this masked caption
        masked_captions.append(masked_caption)
        masked_nps.append(NP)

    # Store the masked captions back to the instance
    img['masked_captions'] = masked_captions
    img['masked_nps'] = masked_nps

    # Store the instance back in the dictionary
    flickr_anns['images'][image_idx] = img

masked_count = 0
for img in flickr_anns['images']:
    masked_count += len(img['masked_captions'])

print(f"# masked captions: {masked_count}")

# masked captions: 8149


### Filter out instances with no viable masked captions

In [None]:
instances = [inst for inst in flickr_anns["images"] if len(inst['masked_captions']) > 0]

print(len(instances))
instances[0]

4158


{'NP_ids': [0, 1, 2, 3, 4, 5],
 'NPs': ['Several climbers', 'row', 'rock', 'man', 'red watches', 'line'],
 'caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the line .',
 'dataset_name': 'flickr',
 'file_name': '1016887272.jpg',
 'height': '500',
 'id': 153901,
 'masked_captions': ['Several climbers in a <mask> are climbing the rock while the man in red watches and holds the line .',
  'Several climbers in a row are climbing the <mask> while the man in red watches and holds the line .',
  'Several climbers in a row are climbing the rock while the man in red watches and holds the <mask> .'],
 'masked_nps': ['row', 'rock', 'line'],
 'original_img_id': 1016887272,
 'sentence_id': 0,
 'tokens_negative': [[0, 97]],
 'tokens_positive_eval': [[[0, 16]],
  [[39, 47]],
  [[54, 61]],
  [[65, 68]],
  [[87, 95]]],
 'width': '333'}

## Apply MLM to get candidate substitution phrases for masked words

In [None]:
from transformers import pipeline
unmasker = pipeline('fill-mask', model=model_type, top_k=k, device=pipeline_device)
# unmasker(["Hello I'm a <mask> computer scientist.", "The man leans over to pick up a <mask> from the ground"])

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 140003114825424 on /root/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373.lock
DEBUG:filelock:Lock 140003114825424 acquired on /root/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /roberta-large/resolve/main/config.json HTTP/1.1" 200 482


Downloading:   0%|          | 0.00/482 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140003114825424 on /root/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373.lock
DEBUG:filelock:Lock 140003114825424 released on /root/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 140003093625552 on /root/.cache/huggingface/transformers/8e36ec2f5052bec1e79e139b84c2c3089cb647

Downloading:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140003093625552 on /root/.cache/huggingface/transformers/8e36ec2f5052bec1e79e139b84c2c3089cb647694ba0f4f634fec7b8258f7c89.c43841d8c5cd23c435408295164cda9525270aa42cd0cc9200911570c0342352.lock
DEBUG:filelock:Lock 140003093625552 released on /root/.cache/huggingface/transformers/8e36ec2f5052bec1e79e139b84c2c3089cb647694ba0f4f634fec7b8258f7c89.c43841d8c5cd23c435408295164cda9525270aa42cd0cc9200911570c0342352.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/tokenizer_config.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.c

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140003091176272 on /root/.cache/huggingface/transformers/7c1ba2435b05451bc3b4da073c8dec9630b22024a65f6c41053caccf2880eb8f.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab.lock
DEBUG:filelock:Lock 140003091176272 released on /root/.cache/huggingface/transformers/7c1ba2435b05451bc3b4da073c8dec9630b22024a65f6c41053caccf2880eb8f.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/merges.txt HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 140003091201808 on /root/.cache/huggingface/transformers/20b5a00a80e27ae9accbe25672aba42ad2d4d4cb2c4b9359b50ca8e34e107d6d.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock
DEBUG:filelock:Lock 140003091201808 acquired on /root/.cache/huggingface/transformers/20b5a00a80e27ae9accbe2567

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140003091201808 on /root/.cache/huggingface/transformers/20b5a00a80e27ae9accbe25672aba42ad2d4d4cb2c4b9359b50ca8e34e107d6d.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock
DEBUG:filelock:Lock 140003091201808 released on /root/.cache/huggingface/transformers/20b5a00a80e27ae9accbe25672aba42ad2d4d4cb2c4b9359b50ca8e34e107d6d.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/tokenizer.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 140003092830416 on /root/.cache/huggingface/transformers/e16a2590deb9e6d73711d6e05bf27d832fa8c1162d807222e043ca650a556964.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730.lock
DEBUG:filelock:Lock 140003092830416 acquired on /root/.cache/huggingface/transformers/e16a2590deb9e6d73711d

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140003092830416 on /root/.cache/huggingface/transformers/e16a2590deb9e6d73711d6e05bf27d832fa8c1162d807222e043ca650a556964.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730.lock
DEBUG:filelock:Lock 140003092830416 released on /root/.cache/huggingface/transformers/e16a2590deb9e6d73711d6e05bf27d832fa8c1162d807222e043ca650a556964.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/added_tokens.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /roberta-large/resolve/main/special_tokens_map.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggi

In [None]:
unmasker("The <mask> is struggling to stay on the horse inside of the bullpen as he goes around the red , white , and blue barrel .")

[{'score': 0.14824217557907104,
  'sequence': 'The veteran is struggling to stay on the horse inside of the bullpen as he goes around the red, white, and blue barrel.',
  'token': 3142,
  'token_str': ' veteran'},
 {'score': 0.1481788605451584,
  'sequence': 'The rookie is struggling to stay on the horse inside of the bullpen as he goes around the red, white, and blue barrel.',
  'token': 4534,
  'token_str': ' rookie'},
 {'score': 0.12870417535305023,
  'sequence': 'The horse is struggling to stay on the horse inside of the bullpen as he goes around the red, white, and blue barrel.',
  'token': 5253,
  'token_str': ' horse'},
 {'score': 0.04824225232005119,
  'sequence': 'The pitcher is struggling to stay on the horse inside of the bullpen as he goes around the red, white, and blue barrel.',
  'token': 7659,
  'token_str': ' pitcher'},
 {'score': 0.028838209807872772,
  'sequence': 'The starter is struggling to stay on the horse inside of the bullpen as he goes around the red, white, 

### Run `roberta-large` on our masked instances (only run once)

In [None]:
masked_instances = []
# Apply the MLM to each masked caption
for inst in tqdm(instances):
    # Each instance has 1 or more masked captions, which can be passed directly 
    # to the unmasking pipeline
    result = unmasker(inst["masked_captions"])

    # Pipeline returns a single-nested list if we pass only 1 item, so nest it
    if not isinstance(result[0], list):
        result = [result]

    # Save the model output with this instance
    inst["pred_captions"] = result
    # Store the instance
    masked_instances.append(inst)

# Store the results for fast loading later
with open(f"{output_base}.json", mode='w') as f:
    json.dump(masked_instances, f, indent=True)

with open(f'{output_base}.pkl', 'wb') as f:
    pickle.dump(masked_instances, f)

  0%|          | 0/4158 [00:00<?, ?it/s]



## Load the `roberta-large` predictions

In [14]:
with open(f"{output_base}.pkl", mode='rb') as f:
    masked_instances = pickle.load(f)

print(len(masked_instances))
masked_instances[0]

4158


{'NP_ids': [0, 1, 2, 3, 4, 5],
 'NPs': ['Several climbers', 'row', 'rock', 'man', 'red watches', 'line'],
 'caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the line .',
 'dataset_name': 'flickr',
 'file_name': '1016887272.jpg',
 'height': '500',
 'id': 153901,
 'masked_captions': ['Several climbers in a <mask> are climbing the rock while the man in red watches and holds the line .',
  'Several climbers in a row are climbing the <mask> while the man in red watches and holds the line .',
  'Several climbers in a row are climbing the rock while the man in red watches and holds the <mask> .'],
 'masked_nps': ['row', 'rock', 'line'],
 'original_img_id': 1016887272,
 'pred_captions': [[{'score': 0.4581820070743561,
    'sequence': 'Several climbers in a group are climbing the rock while the man in red watches and holds the line.',
    'token': 333,
    'token_str': ' group'},
   {'score': 0.13606494665145874,
    'sequence': 'Several climbers

Now that we've run `roberta` on all our masked sentences, we can begin to postprocess the masked sentences to filter out bad predictions

In [16]:
import copy

if False:
    if (cand_word == add_dict['masked_word']) | \
        (cand_word in(people_words)) | \
        (cand_word in(pronoun_words)) | \
        (cand_word in list(set([w for sublist in  [[l.name() for l in syn.lemmas()] for syn in wordnet.synsets(cand_word)] for w in sublist]))) |\
        (str.find(masked_captions[idx], cand_word)>-1) |\
        (nltk.pos_tag([cand_word])[0][1]!= nltk.pos_tag([add_dict['masked_word']])[0][1]):
        pass

def is_valid_caption(pred_caption, masked_word):
    pred_word = pred_caption["token_str"].strip().lower()
    masked_word = masked_word.strip().lower()

    if pred_word == masked_word or pred_word in masked_word or masked_word in pred_word:
        # logging.debug(f"Predicted token == masked token {masked_word}")
        return False

    if pred_word in people_words:
        # logging.debug(f"Predicted word {pred_word} in people words")
        return False

    if pred_word in pronoun_words:
        # logging.debug(f"Predicted word {pred_word} in pronoun words")
        return False

    # TODO append POS of predicted word to it
    if pred_word in set([w for sublist in  [[l.name() for l in syn.lemmas()] for syn in wordnet.synsets(pred_word)] for w in sublist]):
        # logging.debug(f"Predicted word '{pred_word}' is synonym of {masked_word} in: \n\t '{pred_caption['sequence']}'\n")
        return False

    # Don't allow subwords as replacements

    #if nltk.pos_tag([pred_word])[0][1] != nltk.pos_tag([masked_word])[0][1]:
        # logging.debug(f"Predicted word '{pred_word}' is dif POS than {masked_word} in: \n\t '{pred_caption['sequence']}'")
     #   return False

    return True

def create_clean_instance(inst):
    clean_inst = copy.deepcopy(inst)
    del clean_inst["NP_ids"]
    del clean_inst["NPs"]
    del clean_inst["masked_captions"]
    del clean_inst["masked_nps"]
    del clean_inst["pred_captions"]
    clean_inst["orig_caption"] = str(clean_inst["caption"])
    clean_inst["orig_id"] = int(clean_inst["id"])
    clean_inst["orig_sentence_id"] = int(clean_inst["sentence_id"])
    clean_inst["orig_tokens_positive_eval"] = copy.deepcopy(clean_inst["tokens_positive_eval"])
    return clean_inst

def get_img_annotations(inst):
    return [x for x in flickr_anns['annotations'] if x["image_id"] == inst["id"]]

def apply_new_id(anns, image_id, annotation_id):
    for i, ann in enumerate(anns):
        ann["image_id"] = image_id
        ann["id"] = annotation_id
        annotation_id += 1
        anns[i] = ann

    return anns, annotation_id

def find_positive_eval_idx(evals, span):
    # First check for exact span matches
    for idx, eval in enumerate(evals):
        span_check = eval[0] # item is a nested list w/ one item
        if span_check[0] == span[0] and span_check[1] == span[1]:
            return idx, eval
    
    # Now check for one where the end of the span matches
    for idx, eval in enumerate(evals):
        span_check = eval[0] # item is a nested list w/ one item
        if span_check[1] == span[1]:
            return idx, eval

    # Check if the start of the span matches
    for idx, eval in enumerate(evals):
        span_check = eval[0] # item is a nested list w/ one item
        if span_check[0] == span[0]:
            return idx, eval

    # print(f"Can't find suitable span from {evals} to {span}")
    return None, None

def fix_positive_eval_pos(inst, anns):
    # Get span of the original word and find its span in the orig caption
    word = inst["masked_word"]
    caption = inst["orig_caption"]
    new_caption = inst["caption"]
    tokens_pos_eval = inst["tokens_positive_eval"]
    match = re.search(rf"\b{word}\b", caption)
    new_span = match.span()

    # Find the positive_eval_token item in inst that matches up to this
    idx, orig_span = find_positive_eval_idx(tokens_pos_eval, new_span)
    if idx is None:
        # print(f"Couldn't find original span of masked word '{word}' in '{caption}'")
        # print(inst)
        # print()
        return None, None

    inst["tokens_positive_eval_idx"] = idx

    # Swap the starting positive_eval_token item 
    tokens_pos_eval[idx] = [list(new_span)]

    start_idx = idx + 1

    # Move all positive_eval items after it to match back up
    for i, eval in enumerate(tokens_pos_eval[start_idx:]):
        # Unnest the span
        eval = eval[0]

        # Extract the substring represented by the span in the original caption
        orig_substr = caption[eval[0]:eval[1]]

        # Find this substring in the new caption
        new_span = re.search(rf"\b{orig_substr}\b", new_caption)
        if not new_span:
            # print(f"Couldn't find '{orig_substr}' in '{new_caption}'")
            return None, None

        new_span = list(new_span.span())

        # Replace this item in the eval list
        tokens_pos_eval[start_idx + i] = [new_span]

    # Re-align the annotation tokens_positive value
    for i, ann in enumerate(anns):
        # Get the index of ann in the original list
        orig_ann_idx = inst["orig_tokens_positive_eval"].index(ann["tokens_positive"])
        # The ordering of the new tokens_positive is the same,
        # so copy the NEW tokens_positive eval back into this object
        ann["tokens_positive"] = copy.deepcopy(tokens_pos_eval[orig_ann_idx])
        # Save back to the annotations list
        anns[i] = ann

    # Save the list back to the instance
    inst["tokens_positive_eval"] = tokens_pos_eval

    return inst, anns

def unroll_instance(inst, image_id, sentence_id, annotation_id):
    imgs = []
    anns = []

    annotations = get_img_annotations(inst)

    # For each masked caption, unroll the predicted captions and create a new instance
    # We'll want to store the masked caption and the masked NP alongside each instance
    # so should iterate over both lists
    for masked_caption, masked_np, pred_captions in zip(inst["masked_captions"], inst["masked_nps"], inst["pred_captions"]):
        max_score = max([x["score"] for x in pred_captions])
        
        for i, pred_caption in enumerate(pred_captions):
            # Exclude if it's an invalid caption
            if not is_valid_caption(pred_caption, masked_np):
                continue

            # Otherwise construct a new image instance with the predicted
            # sentence as the caption
            clean_inst = create_clean_instance(inst)
            clean_inst["caption"] = pred_caption["sequence"]
            clean_inst["pred_word"] = pred_caption["token_str"].strip()
            clean_inst["masked_caption"] = masked_caption
            clean_inst["masked_word"] = masked_np
            
            # Calculate the scores
            clean_inst["score_raw"] = pred_caption["score"]
            clean_inst["score_scaled"] = pred_caption["score"] / max_score
            clean_inst["score_k"] = i

            # Apply the new IDs
            clean_inst["id"] = image_id
            clean_inst["sentence_id"] = sentence_id

            # Duplicate all associated annotations for this new instance
            new_anns = copy.deepcopy(annotations)

            # Reposition the tokens_positive_eval and tokens_positive of the image / annotation.
            # The masked word may be of different length so we need to reposition
            # the tokens_positive_eval indices
            clean_inst, new_anns = fix_positive_eval_pos(clean_inst, new_anns)

            if not clean_inst:
                # Don't save the instance if it fails the previous checks
                continue

            new_anns, annotation_id = apply_new_id(new_anns, image_id, annotation_id)

            # Save the instance
            image_id += 1
            sentence_id += 1
            anns.extend(new_anns)
            imgs.append(clean_inst)

    return imgs, anns, image_id, sentence_id, annotation_id

unroll_instance(masked_instances[471], 0, 0, 0)
# len(unroll_instance(masked_instances[0])), unroll_instance(masked_instances[0])

([], [], 0, 0, 0)

In [17]:
filtered_instances = []
new_annotations = []

image_id = 0
sentence_id = 0
annotation_id = 0
last_orig_img_id = None
for inst in tqdm(masked_instances):
    # Reset the sentence_id if we reach a new image
    new_orig_img_id = inst["original_img_id"]
    if last_orig_img_id is None or last_orig_img_id != new_orig_img_id:
        sentence_id = 0
    last_orig_img_id = new_orig_img_id

    # Process this image into multiple versions (with different predicted captions)
    imgs, anns, image_id, sentence_id, annotation_id = unroll_instance(
        inst, image_id, sentence_id, annotation_id)
    
    filtered_instances.extend(imgs)
    new_annotations.extend(anns)

print(image_id, annotation_id)
print(len(filtered_instances), len(new_annotations))
assert image_id == len(filtered_instances)
assert sentence_id < 100 # Ensure it gets reset to 0 fairly often
assert annotation_id == len(new_annotations)
assert all([isinstance(x, dict) for x in new_annotations]) # Ensure the annotations are well-formed

  0%|          | 0/4158 [00:00<?, ?it/s]

7893 27727
7893 27727


In [18]:
import random
print(len(masked_instances), len(filtered_instances))
random.choices(filtered_instances, k=3)

4158 7893


[{'caption': 'Michael is getting some serious air while the more formally dressed group gathers in the background.',
  'dataset_name': 'flickr',
  'file_name': '411008311.jpg',
  'height': '333',
  'id': 6720,
  'masked_caption': '<mask> is getting some serious air while the more formally dressed group gathers in the background .',
  'masked_word': 'Someone',
  'orig_caption': 'Someone is getting some serious air while the more formally dressed group gathers in the background .',
  'orig_id': 158134,
  'orig_sentence_id': 1,
  'orig_tokens_positive_eval': [[[0, 7]], [[42, 81]]],
  'original_img_id': 411008311,
  'pred_word': 'Michael',
  'score_k': 5,
  'score_raw': 0.009471684694290161,
  'score_scaled': 0.1838726480996371,
  'sentence_id': 2,
  'tokens_negative': [[0, 101]],
  'tokens_positive_eval': [[[0, 7]], [[42, 81]]],
  'tokens_positive_eval_idx': 0,
  'width': '500'},
 {'caption': 'A group of people wait on the sidewalk, wearing coats.',
  'dataset_name': 'flickr',
  'file_nam

## Save the filtered masked captions to our new flickr json

In [19]:
print(len(filtered_instances), len(new_annotations))

7893 27727


In [20]:
with open("data/final_flickr_mergedGT_test.json") as f:
  flickr_anns = json.load(f)

flickr_anns["images"] = filtered_instances
flickr_anns["annotations"] = new_annotations

with open("data/flickr_test_masked.json", mode="w") as f:
    json.dump(flickr_anns, f)

### Test the shape of instances before/after is the same

In [21]:
with open("data/final_flickr_mergedGT_test.json") as f:
  flickr_anns = json.load(f)

with open("output/flickr_test_masked.json") as f:
  flickr_anns_masked = json.load(f)

flickr_anns["images"][0]

{'caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the line .',
 'dataset_name': 'flickr',
 'file_name': '1016887272.jpg',
 'height': '500',
 'id': 153901,
 'original_img_id': 1016887272,
 'sentence_id': 0,
 'tokens_negative': [[0, 97]],
 'tokens_positive_eval': [[[0, 16]],
  [[39, 47]],
  [[54, 61]],
  [[65, 68]],
  [[87, 95]]],
 'width': '333'}

In [22]:
flickr_anns_masked["images"][5]

{'caption': 'Two children playing on the beach in the USA with the ocean in the background',
 'dataset_name': 'flickr',
 'file_name': '7162685234.jpg',
 'height': '333',
 'id': 5,
 'masked_caption': 'Two children playing on the beach in the <mask> with the ocean in the background',
 'masked_word': 'sand',
 'orig_caption': 'Two children playing on the beach in the sand with the ocean in the background',
 'orig_id': 153906,
 'orig_sentence_id': 0,
 'orig_tokens_positive_eval': [[[0, 12]], [[24, 33]], [[37, 45]], [[51, 60]]],
 'original_img_id': 7162685234,
 'pred_word': 'USA',
 'score_k': 1,
 'score_raw': 0.04310312494635582,
 'score_scaled': 0.10239012625063768,
 'sentence_id': 3,
 'tokens_negative': [[0, 78]],
 'tokens_positive_eval': [[[0, 12]], [[24, 33]], [[41, 45]], [[50, 59]]],
 'tokens_positive_eval_idx': 2,
 'width': '500'}

## Analysis of candidate "negative" captions
### Rules for candidates we will exclude:


1.   Synonym of masked word
2.   We will not be using man/ woman or other people words for replacement
3.   If the word is already in the sentence
4.   Is the masked word itself



In [None]:
candidate_replacements['top_replacement'] = ''
for i in range(10):
  candidate_replacements['top_replacement'] = np.where((candidate_replacements['top_replacement'] == '') & (candidate_replacements[f'Replacement Word {i}'] !=''),
                                                  candidate_replacements[f'Replacement Word {i}'],   candidate_replacements['top_replacement'])  
candidate_replacements[candidate_replacements['top_replacement']!=''][['original_caption', 'masked_caption', 'top_replacement']]    

Unnamed: 0,original_caption,masked_caption,top_replacement
0,"Two young , wet boys playing in the sand on a ...","Two young , wet boys playing in the [MASK] on ...",nfl
1,A topless mannequin with a white skirt is bein...,A topless mannequin with a white skirt is bein...,broadway
2,"A man in a blue shirt , jumping down a hill in...","A man in a blue shirt , jumping down a hill in...",mic
3,A clown is sitting cross-legged on a folding c...,A [MASK] is sitting cross-legged on a folding ...,|
4,Two blonds are passing out fliers and balloons...,Two blonds are passing out fliers and [MASK] o...,gifts
...,...,...,...
571,A man in a nun outfit has a cigarette in his m...,A man in a nun outfit has a cigarette in his [...,mice
572,"A little boy , who 's face is painted like a z...","A little boy , who 's [MASK] is painted like a...",##front
573,A black man with a hat and shades frowning wit...,A black man with a hat and [MASK] frowning wit...,frames
574,a man works outdoors with some machinery,a man works [MASK] with some machinery,chairs


In [None]:
nltk.pos_tag(['walls'])
nltk.pos_tag(['rocks'])

[('rocks', 'NNS')]

In [None]:
top_10_words_collector = []
top_10_scores_collector  = torch.empty((0,10))
candidate_replacements = pd.DataFrame()
model = model.to(device)
legit_cand_tracker = 0
for i, batch in enumerate(masked_dataloader):
    input_ids = batch[0].to(device)
    attn_mask = batch[1].to(device)
    mask_index = torch.where(input_ids == tokenizer.mask_token_id)[1]
    output = model(input_ids, attn_mask)
    logits = output.logits
    softmax = F.softmax(logits, dim = -1)
    mask_words = softmax[torch.arange(input_ids.shape[0]), mask_index, :]  
    top_10_scores, top_10_tokens = torch.topk(mask_words, 10,dim =1, sorted = True)
    for j in range(top_10_tokens.shape[0]):
        batch_words = []
        idx = i*input_ids.shape[0] + j
        add_dict = {}
        add_dict['original_caption'] = original_captions[idx]
        add_dict['masked_caption'] = masked_captions[idx]
        try:
          mask_idx = str.split(masked_captions[idx],' ').index('[MASK]')
        except:
          continue
        add_dict['masked_word'] = str.split(original_captions[idx], ' ')[mask_idx]
        add_dict['NP_idx'] = mask_idx
        add_dict['img_file'] = [img['file_name'] for img in flickr_anns['images'] if mask_idx in(img['NP_ids'])]
        none_count = 0
        for k in range(top_10_tokens.shape[1]):
            #Check conditions for replacement before adding
            #If word == masked word to replace:
            cand_word = str.replace(tokenizer.decode(top_10_tokens[j,k].detach().cpu()), ' ', '')
            if (cand_word == add_dict['masked_word']) | \
               (cand_word in(people_words)) | \
               (cand_word in(pronoun_words)) | \
               (cand_word in list(set([w for sublist in  [[l.name() for l in syn.lemmas()] for syn in wordnet.synsets(cand_word)] for w in sublist]))) |\
               (str.find(masked_captions[idx], cand_word)>-1) |\
               (nltk.pos_tag([cand_word])[0][1]!= nltk.pos_tag([add_dict['masked_word']])[0][1]):
               
              add_dict[f'Replacement Word {k}'] = None
              add_dict[f'Replacement Score {k}'] = None
              add_dict[f'New Sentence {k}'] = None
              none_count += 1
            else:
              add_dict[f'Replacement Word {k}'] = str.replace(tokenizer.decode(top_10_tokens[j,k].detach().cpu()), ' ', '')
              add_dict[f'Replacement Score {k}'] = top_10_scores[j,k].detach().cpu()
              add_dict[f'New Sentence {k}'] = str.replace(masked_captions[idx], '[MASK]', add_dict[f'Replacement Word {k}'])
        if none_count < 10:
          candidate_replacements = candidate_replacements.append(add_dict, ignore_index = True)
          legit_cand_tracker+= 1
        else:
          continue

    if i % 10 == 0:
        print(f"Running for batch: {i / len(masked_dataloader)}")
        print(add_dict)
        print(f'# sentences w/ legitimate candidates: {legit_cand_tracker}')

candidate_replacements.to_csv('output/candidate_object_replacements.csv')
candidate_replacements

Running for batch: 0.0
{'original_caption': 'A collage of one person climbing a cliff .', 'masked_caption': 'A collage of one person climbing a [MASK] .', 'masked_word': 'cliff', 'NP_idx': 7, 'img_file': ['1016887272.jpg'], 'Replacement Word 0': None, 'Replacement Score 0': None, 'New Sentence 0': None, 'Replacement Word 1': None, 'Replacement Score 1': None, 'New Sentence 1': None, 'Replacement Word 2': None, 'Replacement Score 2': None, 'New Sentence 2': None, 'Replacement Word 3': None, 'Replacement Score 3': None, 'New Sentence 3': None, 'Replacement Word 4': 'mt', 'Replacement Score 4': tensor(0.0101), 'New Sentence 4': 'A collage of one person climbing a mt .', 'Replacement Word 5': 'chicago', 'Replacement Score 5': tensor(0.0101), 'New Sentence 5': 'A collage of one person climbing a chicago .', 'Replacement Word 6': 'seattle', 'Replacement Score 6': tensor(0.0092), 'New Sentence 6': 'A collage of one person climbing a seattle .', 'Replacement Word 7': 'antarctica', 'Replacement

Unnamed: 0,original_caption,masked_caption,masked_word,NP_idx,img_file,Replacement Word 0,Replacement Score 0,New Sentence 0,Replacement Word 1,Replacement Score 1,...,New Sentence 6,Replacement Word 7,Replacement Score 7,New Sentence 7,Replacement Word 8,Replacement Score 8,New Sentence 8,Replacement Word 9,Replacement Score 9,New Sentence 9
0,A group of people are rock climbing on a rock ...,A [MASK] of people are rock climbing on a rock...,group,1.0,[1016887272.jpg],,,,##ist,tensor(0.0049),...,,##ista,tensor(0.0010),A ##ista of people are rock climbing on a rock...,,,,,,
1,A collage of one person climbing a cliff .,A collage of one person climbing a [MASK] .,cliff,7.0,[1016887272.jpg],,,,,,...,A collage of one person climbing a seattle .,antarctica,tensor(0.0088),A collage of one person climbing a antarctica .,,,,disneyland,tensor(0.0085),A collage of one person climbing a disneyland .
2,Two children playing on the beach in the sand ...,Two children playing on the beach in the sand ...,background,14.0,[1016887272.jpg],,,,,,...,,,,,,,,,,
3,A smiling clown is sitting outdoors by his booth,A smiling clown is sitting [MASK] by his booth,outdoors,5.0,[1016887272.jpg],onions,tensor(0.2679),A smiling clown is sitting onions by his booth,potatoes,tensor(0.1327),...,A smiling clown is sitting mushrooms by his booth,,,,tomatoes,tensor(0.0210),A smiling clown is sitting tomatoes by his booth,grapes,tensor(0.0204),A smiling clown is sitting grapes by his booth
4,A blond woman in a crowded area hands out flie...,A blond woman in a crowded area hands out [MAS...,fliers,9.0,[1016887272.jpg],,,,,,...,A blond woman in a crowded area hands out dogs...,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
123,a girl in a red shirt looks at her mom in a cr...,a girl in a red shirt looks at her [MASK] in a...,mom,9.0,[1016887272.jpg],,,,,,...,,christmas,tensor(0.0125),a girl in a red shirt looks at her christmas i...,,,,,,
124,A group of people with their backs to the came...,A group of people with their [MASK] to the cam...,backs,6.0,[1016887272.jpg],boots,tensor(0.5741),A group of people with their boots to the came...,,,...,A group of people with their sneakers to the c...,,,,socks,tensor(0.0067),A group of people with their socks to the came...,jackets,tensor(0.0066),A group of people with their jackets to the ca...
125,A little girl looking into a females face in t...,A little girl looking into a females face in t...,midst,10.0,[1016887272.jpg],,,,,,...,,##front,tensor(2.8356e-06),A little girl looking into a females face in t...,,,,,,
126,A referee wearing a black and white uniform wa...,A referee wearing a black and white uniform wa...,ice,18.0,[1016887272.jpg],,,,,,...,,harley,tensor(0.0025),A referee wearing a black and white uniform wa...,,,,,,


#MLM Object Replacement (Sara's)
This script is used to create "negative" captions for images from the Flickr30K test dataset using MLM. Steps are roughly as follows:
### 1. Identification of objects to mask and replace in captions
 * We use nltk parsing to extract noun phrases or objects that are candidates for substitution in the captions for the different images

### 2. Apply MLM to get candidate substitution phrases for masked words
 * We apply BertForMaskedLM on masked sentences

### 3. Analysis of candidate "negative" captions
  * Not all replacement words/ phrases are appropriate. We apply some post-processing to select reasonable candidates from the MLM

## Inital Set-up

In [None]:
!pip install transformers
!pip install nltk

Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 34.6 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 77.3 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 8.2 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 63.2 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 85.8 MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel fo

In [None]:
import os, sys
import pandas as pd
import numpy as np
import json
import spacy
import en_core_web_sm
import nltk
from nltk import word_tokenize
from nltk import pos_tag
from transformers import BertTokenizer, BertForMaskedLM
from torch.nn import functional as F
import torch
import re
from nltk.corpus import wordnet

In [None]:
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('wordnet')

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


True

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
data_dir = '/content/gdrive/MyDrive/DL_systems/final_project/data/'
flickr_anns = json.load(open(os.path.join(data_dir, 'final_flickr_mergedGT_test.json'),'r'))
len(flickr_anns['images'])
len(flickr_anns['annotations'])

14481

Example of image data in annotation file:

In [None]:
flickr_anns['images'][0]

{'caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the line .',
 'dataset_name': 'flickr',
 'file_name': '1016887272.jpg',
 'height': '500',
 'id': 153901,
 'original_img_id': 1016887272,
 'sentence_id': 0,
 'tokens_negative': [[0, 97]],
 'tokens_positive_eval': [[[0, 16]],
  [[39, 47]],
  [[54, 61]],
  [[65, 68]],
  [[87, 95]]],
 'width': '333'}

Example of annotation data in annotation file:

In [None]:
flickr_anns['annotations'][0]

{'area': 21186.0,
 'bbox': [74.0, 302.0, 107.0, 198.0],
 'category_id': 1,
 'id': 430844,
 'image_id': 153901,
 'iscrowd': 0,
 'phrase_ids': 3,
 'tokens_positive': [[54, 61]]}

## Identification of objects to mask and replace in captions

In [None]:
article_list = ['the', 'a', 'an']
def parse_NP(sentence):
    '''
    Function for extracting noun phrases from sentences (these will be objects and object + modifiers we will be replacing to create negative captions)
    '''

    #Define grammar for parsing tree
    grammar = """NP: {<DT>?<JJ>*<NN.*>+}
                    RELATION: {<V.*>}
                                {<DT>?<JJ>*<NN.*>+}
                    ENTITY: {<NN.*>}"""
    
    parser = nltk.RegexpParser(grammar)
    NP_list = [' '.join(leaf[0] for leaf in tree.leaves()) for tree in parser.parse(nltk.pos_tag(nltk.word_tokenize(sentence))).subtrees() if tree.label() =='NP']
    
    #remove articles from NP
    NP_list = [' '.join([w for w in str.split(np,' ') if str.lower(w) not in(article_list)]) for np in NP_list]
    return NP_list


In [None]:
'''Update annotations dictionary (imafge keys specificially to include noun phrases and ids assocaited w/ them)
This is important because these NPs are what we are going to find candidate replacements for using MLM
Each caption can have multiople NPs so we need to have ids associated with each to map back to the original image/caption once we tokenize all the captions and load them in dataloader
'''
np_id = 0
for i, img in enumerate(flickr_anns['images']):
    img['NPs'] = parse_NP(img['caption'])
    img['NP_ids'] = []
    for i in range(len(img['NPs'])):
        #Add in ID for noun phrases for mapping back when we add to dataset
        img['NP_ids'].append(np_id)
        np_id += 1

In [None]:
flickr_anns['images'][0]

{'NP_ids': [0, 1, 2, 3, 4, 5],
 'NPs': ['Several climbers', 'row', 'rock', 'man', 'red watches', 'line'],
 'caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the line .',
 'dataset_name': 'flickr',
 'file_name': '1016887272.jpg',
 'height': '500',
 'id': 153901,
 'original_img_id': 1016887272,
 'sentence_id': 0,
 'tokens_negative': [[0, 97]],
 'tokens_positive_eval': [[[0, 16]],
  [[39, 47]],
  [[54, 61]],
  [[65, 68]],
  [[87, 95]]],
 'width': '333'}

In [None]:
import numpy as np
people_words = np.array(['man', 'woman', 'men', 'women', 'children', 'child', 'girl', 'boy', 'boys', 'girls', 'father', 'fathers', 'son', 'sons', 'husband', 'husbands', 'mother','mothers', 'parent', 'parents', 'daughter', 'daughters', 'wife', 'wives',
                  'spouse', 'spouses', 'partner', 'partners', 'brother', 'brothers', 'sister', 'sisters', 'sibling', 'siblings', 'grandfather','grandfathers',	'grandmother', 'grandmothers',	'grandparents', 'uncle', 'uncles', 'aunt', 'aunts',
                   'nephew', 'nephews', 'niece', 'nieces', 'cousin', 'cousins', 'person', 'people'])
pronoun_words = np.array(["this", "that", "these", "those", "another", "anybody", "anyone", "anything", "each", "either", "enough", "everybody", "everyone",
                "everything", "little", "much", "neither", "nobody", "no one", "nothing", "one", "other", "somebody",
                "something", "both", "few", "fewer", "many", "others", "several", "All", "any", "more", "most",
                "someone", "none", "some", "such", "who", "whom", "whose", "what", "which", 'it', 'them', 'they', 'her', 'she', 'him', 'he', 'I'])

In [None]:
masked_captions = []
original_captions = []
for i, img in enumerate(flickr_anns['images']):
    caption = img['caption']
    for j in range(len(img['NPs'])):
        NP = img['NPs'][j]
        #print(NP)
        if str.count(caption, NP)>1: 
            continue
        elif len(np.where(np.array([str.find(NP, w) for w in people_words])>-1)[0])>0: 
          continue
        elif len(np.where(np.array([str.find(NP, w) for w in pronoun_words])>-1)[0])>0: 
          continue
        elif len(str.split(NP, ' '))>1:
          continue
        elif str.find(caption, NP)==-1:
          continue
        else:
            masked_captions.append(str.replace(caption, NP, '[MASK]'))
            original_captions.append(caption)
print(f"# masked captions: {len(masked_captions)}")

# masked captions: 7449


In [None]:
#Randomly sample 1000 masked captions we want to generate replacement phrases for (we don't need all 13,912)
np.random.seed(42)
sample_idx = np.array(np.random.choice(range(len(masked_captions)), size = 1000))
masked_captions_sample = np.array(masked_captions)[sample_idx]
original_captions_sample = np.array(original_captions)[sample_idx]
print(len(masked_captions_sample))

1000


## Apply MLM to get candidate substitution phrases for masked words

In [None]:
from transformers import BertTokenizer, BertForMaskedLM
from torch.nn import functional as F
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
def encode(data, tokenizer):
    '''
    Function for encoding captions 
    '''
    input_ids = []
    attention_mask = []
    for text in data:
        tokenized_text = tokenizer.encode_plus(text,
                                            max_length=128,
                                            add_special_tokens = True,
                                            pad_to_max_length=True,
                                            return_attention_mask=True)
        input_ids.append(tokenized_text['input_ids'])
        attention_mask.append(tokenized_text['attention_mask'])
    
    return torch.tensor(input_ids, dtype=torch.long), torch.tensor(attention_mask, dtype=torch.long)

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased',    return_dict = True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
use_sample = False
if use_sample:
  input_ids, attention_mask  = encode(masked_captions_sample, tokenizer )
else:
  input_ids, attention_mask  = encode(masked_captions, tokenizer )
masked_dset = torch.utils.data.TensorDataset(input_ids, attention_mask)
masked_dataloader = torch.utils.data.DataLoader(masked_dset, batch_size = 10)
len(masked_dset)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


7449

In [None]:
device = 'cuda:0'
top_10_words_collector = []
top_10_scores_collector  = torch.empty((0,10))
candidate_replacements = pd.DataFrame()
model = model.to(device)
legit_cand_tracker = 0
for i, batch in enumerate(masked_dataloader):
    input_ids = batch[0].to(device)
    attn_mask = batch[1].to(device)
    mask_index = torch.where(input_ids == tokenizer.mask_token_id)[1]
    output = model(input_ids, attn_mask)
    logits = output.logits
    softmax = F.softmax(logits, dim = -1)
    mask_words = softmax[torch.arange(input_ids.shape[0]), mask_index, :]  
    top_10_scores, top_10_tokens = torch.topk(mask_words, 10,dim =1, sorted = True)
    for j in range(top_10_tokens.shape[0]):
        batch_words = []
        idx = i*input_ids.shape[0] + j
        add_dict = {}
        add_dict['original_caption'] = original_captions[idx]
        add_dict['masked_caption'] = masked_captions[idx]
        try:
          mask_idx = str.split(masked_captions[idx],' ').index('[MASK]')
        except:
          continue
        add_dict['masked_word'] = str.split(original_captions[idx], ' ')[mask_idx]
        add_dict['NP_idx'] = mask_idx
        add_dict['img_file'] = [img['file_name'] for img in flickr_anns['images'] if mask_idx in(img['NP_ids'])]
        none_count = 0
        for k in range(top_10_tokens.shape[1]):
            #Check conditions for replacement before adding
            #If word == masked word to replace:
            cand_word = str.replace(tokenizer.decode(top_10_tokens[j,k].detach().cpu()), ' ', '')
            if (cand_word == add_dict['masked_word']) | \
               (cand_word in(people_words)) | \
               (cand_word in(pronoun_words)) | \
               (cand_word in list(set([w for sublist in  [[l.name() for l in syn.lemmas()] for syn in wordnet.synsets(cand_word)] for w in sublist]))) |\
               (str.find(masked_captions[idx], cand_word)>-1) |\
               (nltk.pos_tag([cand_word])[0][1]!= nltk.pos_tag([add_dict['masked_word']])[0][1]):
               
              add_dict[f'Replacement Word {k}'] = None
              add_dict[f'Replacement Score {k}'] = None
              add_dict[f'New Sentence {k}'] = None
              none_count += 1
            else:
              add_dict[f'Replacement Word {k}'] = str.replace(tokenizer.decode(top_10_tokens[j,k].detach().cpu()), ' ', '')
              add_dict[f'Replacement Score {k}'] = top_10_scores[j,k].detach().cpu()
              add_dict[f'New Sentence {k}'] = str.replace(masked_captions[idx], '[MASK]', add_dict[f'Replacement Word {k}'])
        if none_count < 10:
          candidate_replacements = candidate_replacements.append(add_dict, ignore_index = True)
          legit_cand_tracker+= 1
        else:
          continue

    if i %10 == 0:
        print(f"Running for batch: {i}")
        print(add_dict)
        print(f'# sentences w/ legitimate candidates: {legit_cand_tracker}')
    
candidate_replacements.to_csv(os.path.join(data_dir, 'candidate_object_replacements.csv'))

Running for batch: 0
{'original_caption': 'A collage of one person climbing a cliff .', 'masked_caption': 'A collage of one person climbing a [MASK] .', 'masked_word': 'cliff', 'NP_idx': 7, 'img_file': ['1016887272.jpg'], 'Replacement Word 0': None, 'Replacement Score 0': None, 'New Sentence 0': None, 'Replacement Word 1': None, 'Replacement Score 1': None, 'New Sentence 1': None, 'Replacement Word 2': None, 'Replacement Score 2': None, 'New Sentence 2': None, 'Replacement Word 3': None, 'Replacement Score 3': None, 'New Sentence 3': None, 'Replacement Word 4': None, 'Replacement Score 4': None, 'New Sentence 4': None, 'Replacement Word 5': None, 'Replacement Score 5': None, 'New Sentence 5': None, 'Replacement Word 6': None, 'Replacement Score 6': None, 'New Sentence 6': None, 'Replacement Word 7': None, 'Replacement Score 7': None, 'New Sentence 7': None, 'Replacement Word 8': None, 'Replacement Score 8': None, 'New Sentence 8': None, 'Replacement Word 9': None, 'Replacement Score 9'

## Analysis of candidate "negative" captions
### Rules for candidates we will exclude:


1.   Synonym of masked word
2.   We will not be using man/ woman or other people words for replacement
3.   If the word is already in the sentence
4.   Is the masked word itself



In [None]:
candidate_replacements.fillna('', inplace = True)

In [None]:
candidate_replacements['top_replacement'] = ''
for i in range(10):
  candidate_replacements['top_replacement'] = np.where((candidate_replacements['top_replacement'] == '') & (candidate_replacements[f'Replacement Word {i}'] !=''),
                                                  candidate_replacements[f'Replacement Word {i}'],   candidate_replacements['top_replacement'])  
candidate_replacements[candidate_replacements['top_replacement']!=''][['original_caption', 'masked_caption', 'top_replacement']]    

Unnamed: 0,original_caption,masked_caption,top_replacement
0,"Two young , wet boys playing in the sand on a ...","Two young , wet boys playing in the [MASK] on ...",nfl
1,A topless mannequin with a white skirt is bein...,A topless mannequin with a white skirt is bein...,broadway
2,"A man in a blue shirt , jumping down a hill in...","A man in a blue shirt , jumping down a hill in...",mic
3,A clown is sitting cross-legged on a folding c...,A [MASK] is sitting cross-legged on a folding ...,|
4,Two blonds are passing out fliers and balloons...,Two blonds are passing out fliers and [MASK] o...,gifts
...,...,...,...
571,A man in a nun outfit has a cigarette in his m...,A man in a nun outfit has a cigarette in his [...,mice
572,"A little boy , who 's face is painted like a z...","A little boy , who 's [MASK] is painted like a...",##front
573,A black man with a hat and shades frowning wit...,A black man with a hat and [MASK] frowning wit...,frames
574,a man works outdoors with some machinery,a man works [MASK] with some machinery,chairs


In [None]:
nltk.pos_tag(['walls'])
nltk.pos_tag(['rocks'])

[('rocks', 'NNS')]

In [None]:
candidate_replacements[candidate_replacements['top_replacement']!=''][['original_caption', 'masked_caption', 'top_replacement']]    

Unnamed: 0,original_caption,masked_caption,top_replacement
19,Several climbers in a row are climbing the roc...,Several climbers in a row are climbing the [MA...,walls
38,Seven climbers are ascending a rock face whils...,Seven climbers are ascending a rock face whils...,reins
39,Seven climbers are ascending a rock face whils...,Seven climbers are ascending a rock face whils...,reins
101,Two children playing on the beach in the sand ...,Two children playing on the beach in the [MASK...,rocks
102,Two children playing on the beach in the sand ...,Two children playing on the beach in the [MASK...,rocks
...,...,...,...
43025,A man known as Deleon speaks at a Q&A .,A man known as [MASK] speaks at a Q&A .,toys
43026,A man known as Deleon speaks at a Q&A .,A man known as [MASK] speaks at a Q&A .,toys
43027,A man known as Deleon speaks at a Q&A .,A man known as [MASK] speaks at a Q&A .,toys
43028,A man known as Deleon speaks at a Q&A .,A man known as [MASK] speaks at a Q&A .,toys
