In [6]:
import csv, pyodbc

# Connection constants: path to the data.mdb file inside the vqa_train text DB
# and the name of the ODBC driver to request.
MDB = 'txt_db/vqa_train.db/data.mdb'
DRV = '{Microsoft Access Driver (*.mdb, *.accdb)}'

# Try to open data.mdb through ODBC with the Access driver.
# NOTE(review): the recorded output of this cell is a unixODBC
# "Can't open lib ... file not found" error, so `con` and `cur` are never
# created here. The later cells read the text DBs with lmdb instead;
# data.mdb is presumably an LMDB data file rather than an Access database —
# confirm and consider dropping this cell.
con = pyodbc.connect('DRIVER={};DBQ={}'.format(DRV,MDB))
cur = con.cursor()

Error: ('01000', "[01000] [unixODBC][Driver Manager]Can't open lib 'Microsoft Access Driver (*.mdb, *.accdb)' : file not found (0) (SQLDriverConnect)")

In [6]:
import lmdb
from lz4.frame import compress, decompress
import msgpack
import json
import re

# Looking at dataset structure

In [2]:
# Open the vqa_vg text DB as an LMDB environment; map_size is 4 * 1024**4
# bytes (4 TiB).
# NOTE(review): the environment is opened with readonly=False and create=True
# although the visible cells below only read from it — confirm that write
# access / directory creation is really wanted for this inspection.
env = lmdb.open('txt_db/vqa_vg.db', readonly=False, create=True,
                                 map_size=4 * 1024**4)

In [3]:
txn = env.begin()

In [4]:
cursor = txn.cursor()

In [5]:
cursor.first()

True

In [6]:
key, value = cursor.item()

In [7]:
key.decode()

'VG_1000001'

In [8]:
q = msgpack.loads(decompress(value))

In [9]:
q

{b'answers': [{b'answer': b'orange'},
  {b'answer': b'orange'},
  {b'answer': b'orange'},
  {b'answer': b'orange'},
  {b'answer': b'orange'},
  {b'answer': b'orange'},
  {b'answer': b'orange'},
  {b'answer': b'orange'},
  {b'answer': b'orange'},
  {b'answer': b'orange'}],
 b'image_id': 150305,
 b'img_fname': b'vg_000000150305.npz',
 b'input_ids': [1327, 2942, 1110, 1103, 15398, 1113, 1103, 3208, 1200, 136],
 b'question': b'What color is the bucket on the heater?',
 b'question_id': b'VG_1000001',
 b'target': {b'labels': [2], b'scores': [1.0]},
 b'toked_question': [b'What',
  b'color',
  b'is',
  b'the',
  b'bucket',
  b'on',
  b'the',
  b'heat',
  b'@@##er',
  b'@@?']}

In [10]:
q['question'.encode('utf-8')].decode().split(" ")

['What', 'color', 'is', 'the', 'bucket', 'on', 'the', 'heater?']

# Getting words in VQA

In [66]:


def get_q_info(key, value):
    """Decode one text-DB record into its question id, text and word set.

    Args:
        key: raw record key (bytes); decoded to give the question id.
        value: raw record value, an lz4-frame-compressed msgpack payload
            that stores the question under the ``question`` key.

    Returns:
        A tuple ``(q_id, q_text, q_set)``: ``q_text`` is the question with a
        single trailing ``'?'`` removed, and ``q_set`` is the set of
        non-empty tokens obtained by splitting ``q_text`` on single spaces.

    Raises:
        KeyError: if the record has no ``question`` entry.
    """
    q_id = key.decode()
    record = msgpack.loads(decompress(value))
    # Depending on the msgpack version/options the payload is unpacked with
    # bytes keys and values (as in the record shown earlier in this notebook)
    # or with str ones; accept both.
    question = record[b'question'] if b'question' in record else record['question']
    q_text = question.decode() if isinstance(question, bytes) else question
    # endswith() rather than q_text[-1], so an empty question does not raise
    # IndexError.
    if q_text.endswith('?'):
        q_text = q_text[:-1]
    # Consecutive spaces produce empty tokens when splitting; drop them.
    q_set = set(q_text.split(" "))
    q_set.discard('')
    return q_id, q_text, q_set

def update_dict(key, value, word_dict, word_dict_counts):
    """Record the words of one question in both accumulator dicts (in place).

    Args:
        key: raw record key, passed through to ``get_q_info``.
        value: raw record value, passed through to ``get_q_info``.
        word_dict: maps each word to the list of ids of questions using it.
        word_dict_counts: maps each word to the number of such questions.
    """
    question_id, _text, tokens = get_q_info(key, value)
    for token in tokens:
        if token in word_dict:
            # Word seen before: extend its question list and bump its count.
            word_dict[token].append(question_id)
            word_dict_counts[token] += 1
            continue
        # First occurrence of this word.
        word_dict[token] = [question_id]
        word_dict_counts[token] = 1

# LMDB environments already opened by get_cursor, keyed by path. Any cache
# left in the namespace by an earlier run of this cell is reused, so
# re-running the cell does not lead to the same environment being opened
# twice.
_ENV_CACHE = globals().get('_ENV_CACHE', {})


def get_cursor(db):
    """Return a fresh read cursor over the text DB ``txt_db/<db>``.

    The LMDB environment for each path is opened once and then reused:
    py-lmdb documents opening the same environment more than once in a
    process as a serious error, and this function is called several times
    per DB. The environment is opened read-only and without ``create``, so a
    mistyped DB name raises instead of silently creating an empty DB.

    Args:
        db: directory name of the DB under ``txt_db/``.

    Returns:
        A cursor from a new transaction on the cached environment.
    """
    path = 'txt_db/{}'.format(db)
    env = _ENV_CACHE.get(path)
    if env is None:
        env = lmdb.open(path, readonly=True, create=False,
                        map_size=4 * 1024**4)
        _ENV_CACHE[path] = env
    return env.begin().cursor()
            
def enumerate_db(db, word_dict, word_dict_counts):
    """Feed every record of one text DB into the word accumulators.

    Prints a progress banner, then calls ``update_dict`` for each
    ``(key, value)`` pair yielded by the DB cursor; both dicts are updated
    in place.
    """
    print("#### DOING {} ####".format(db))
    for record_key, record_value in get_cursor(db):
        update_dict(record_key, record_value, word_dict, word_dict_counts)
        
def count_number_of_questions(db):
    """Return how many records the cursor over text DB ``db`` yields.

    Prints a progress banner before iterating.
    """
    print("#### COUNTING {} ####".format(db))
    return sum(1 for _key, _value in get_cursor(db))

def count_valid_questions(db, word_set):
    """Count the questions in ``db`` that contain at least one word of ``word_set``.

    Prints a progress banner, then checks the word set of each record (as
    produced by ``get_q_info``) for any overlap with ``word_set``.
    """
    print("#### COUNTING VALID {} ####".format(db))
    matches = 0
    for record_key, record_value in get_cursor(db):
        _q_id, _q_text, tokens = get_q_info(record_key, record_value)
        if not tokens.isdisjoint(word_set):
            matches += 1
    return matches

In [67]:
# Text DBs (directories under txt_db/) to scan; 'vqa_test.db' is not included.
list_of_data = ['vqa_devval.db', 'vqa_train.db', 'vqa_trainval.db', 'vqa_vg.db']
# word -> list of ids of the questions containing it (filled by update_dict)
word_dict = {}
# word -> number of questions containing it (filled by update_dict)
word_dict_counts = {}

In [3]:
import lmdb
from lz4.frame import compress, decompress
import msgpack

for db in list_of_data:
    enumerate_db(db, word_dict, word_dict_counts)


#### DOING vqa_devval.db ####
#### DOING vqa_train.db ####
#### DOING vqa_trainval.db ####
#### DOING vqa_vg.db ####


In [4]:
num_questions = 0
for db in list_of_data:
    num_questions += count_number_of_questions(db)
num_questions

#### COUNTING vqa_devval.db ####
#### COUNTING vqa_train.db ####
#### COUNTING vqa_trainval.db ####
#### COUNTING vqa_vg.db ####


1129238

In [7]:
with open('vqa_word_counts.json', 'w') as outfile:
    json.dump(word_dict_counts, outfile)
with open('vqa_word_questions.json', 'w') as outfile:
    json.dump(word_dict, outfile)

# Finding the VQA question words that are missing from the BERT vocabulary

In [8]:
def read_file_as_list(file_name, encoding='utf-8'):
    """Read a text file into a list with one entry per line.

    Args:
        file_name: path of the file to read.
        encoding: text encoding of the file. Defaults to UTF-8 so the result
            does not depend on the platform's default encoding (the BERT
            vocab file read with this helper contains non-ASCII tokens).

    Returns:
        The lines of the file, each with leading and trailing whitespace
        (including the newline) stripped.
    """
    with open(file_name, encoding=encoding) as file:
        return [line.strip() for line in file]

In [9]:
bert_vocab_set = set(read_file_as_list('bert_vocab/vocab.txt'))
word_dict_counts = {}
with open('vqa_word_counts.json') as infile:
    word_dict_counts = json.load(infile)

In [10]:
def is_oov_cleaned(word, vocab):
    """Tell whether ``word`` is out of vocabulary even after light cleaning.

    The word counts as in-vocabulary when any of these is in ``vocab``:
    the word itself; the word without a final ``'s'`` (words longer than one
    character); or the word without its last two characters when the
    second-to-last character is an apostrophe (words longer than two
    characters, e.g. ``"man's"`` -> ``"man"``).

    Returns:
        ``True`` if none of the candidate forms is in ``vocab``.
    """
    candidates = [word]
    if len(word) > 1 and word[-1] == 's':
        candidates.append(word[:-1])
    if len(word) > 2 and word[-2] == "'":
        candidates.append(word[:-2])
    return not any(form in vocab for form in candidates)

In [11]:
# Keep only the words (with their question counts) that are not literal
# entries of the BERT vocabulary.
# NOTE(review): this is a plain membership test; is_oov_cleaned (defined in
# the previous cell) is not applied here — confirm that is intended.
vqa_words_not_in_bert = {}
for key in word_dict_counts.keys():
    if key not in bert_vocab_set:
        vqa_words_not_in_bert[key] = word_dict_counts[key]


In [12]:
with open('vqa_words_not_in_bert.txt', 'w') as outfile:
    json.dump(vqa_words_not_in_bert, outfile)

In [13]:
vqa_words_not_in_bert

{'kites': 1490,
 'kite': 2765,
 'stickers': 242,
 'stoplight': 153,
 "person's": 2965,
 'rained': 79,
 'zebra': 3482,
 "it's": 606,
 'giraffe': 5068,
 'zebras': 4166,
 'federally-mandated': 1,
 "What's": 3828,
 'giraffes': 4184,
 'raining': 1386,
 'light.': 2,
 "phone's": 23,
 'suitcases': 617,
 "man's": 14332,
 'utensil': 1182,
 'hydrant': 3182,
 'donuts': 1441,
 'leash': 269,
 'motorhomes': 1,
 'opaque': 33,
 'cooks': 42,
 'chefs': 68,
 'blinds': 407,
 'unmade': 18,
 'futon': 15,
 "someone's": 367,
 'Frisbee': 579,
 'cabbage': 40,
 'bananas': 2388,
 'pears': 37,
 'seasoned': 15,
 'sepia': 14,
 "woman's": 6036,
 'purses': 23,
 'Asians': 9,
 'handbag': 47,
 'umbrellas': 2150,
 "child's": 1050,
 'pedestrians': 104,
 'surfboard': 2520,
 'living,': 5,
 'corner,': 7,
 'squatting': 131,
 "catcher's": 250,
 'divisible': 3,
 'servings': 35,
 'oats': 2,
 "girl's": 1948,
 'indoors': 667,
 'earrings': 172,
 "kid's": 161,
 'Glazed': 2,
 'doughnuts': 489,
 'doughnut': 391,
 'muffins': 51,
 'skateb

In [68]:
num_remaining_questions = 0
for db in list_of_data:
    num_remaining_questions += count_valid_questions(db, set(vqa_words_not_in_bert.keys()))
num_remaining_questions

#### COUNTING VALID vqa_devval.db ####
#### COUNTING VALID vqa_train.db ####
#### COUNTING VALID vqa_trainval.db ####
#### COUNTING VALID vqa_vg.db ####


211647

# Splitting Test to Val

In [15]:
vqa_full_test_words = {}
with open('vqa_words_not_in_bert.txt') as infile:
    vqa_full_test_words = json.load(infile)

In [16]:
full_word_list = list(vqa_full_test_words)
full_word_list.sort()
full_word_list

['"&"',
 '"+"',
 '"11"',
 '"1560',
 '"300"',
 '"4"',
 '"40"',
 '"41"',
 '"50"',
 '"6',
 '"6"',
 '"70"',
 '"911"',
 '"A"',
 '"A49"',
 '"AHH"',
 '"ANTIQUE"',
 '"ATM"',
 '"AVE"',
 '"African"',
 '"Ahead"',
 '"Air',
 '"Ajax"',
 '"Alive"',
 '"All',
 '"Amazing"',
 '"Asian"',
 '"Automat"',
 '"B"',
 '"BACK',
 '"BELLTOWN"',
 '"BRAT"',
 '"BREAK"',
 '"BRITISH',
 '"BURGER"',
 '"Baby',
 '"Bar"',
 '"Bet"',
 '"Big',
 '"Black',
 '"Born',
 '"Braden',
 '"Brigantine"',
 '"Brooklyn',
 '"Brown',
 '"Buildstrong"',
 '"Butt',
 '"C"',
 '"COAL',
 '"COCINA"',
 '"COMMUNITY"',
 '"California"',
 '"Casey"',
 '"Cash"',
 '"Catalunya"',
 '"Cedar"',
 '"Central"',
 '"Chicago"',
 '"D"',
 '"DB"',
 '"DELL"',
 '"DONUTS"',
 '"Dainty',
 '"Dell"',
 '"Dimple"',
 '"Do',
 '"Downtown"',
 '"E"',
 '"EASON"',
 '"EAST',
 '"ECOLIERS"',
 '"EMR"',
 '"ESPN',
 '"Easy"',
 '"Eat',
 '"El"',
 '"Evil',
 '"Ex"',
 '"Exploring',
 '"F"',
 '"FANTA"',
 '"FLORIDA"',
 '"FN437"',
 '"FOR',
 '"FRANCE"',
 '"FRIDAY"',
 '"FULLY',
 '"Field',
 '"Flash',
 '"For',

In [61]:
def get_base_words(word):
    """Return the normalized form of ``word`` plus candidate stems.

    Normalization removes every character that is not an ASCII letter, a
    digit, whitespace or a colon, then strips surrounding whitespace and
    lowercases. If the normalized word is longer than one of the suffixes
    below and ends with it, candidate stems are produced by cutting trailing
    characters; only the first matching suffix is used.

    Returns:
        A list whose first element is the normalized word, followed by its
        candidate stems (if any), e.g. ``'kites,'`` ->
        ``['kites', 'kit', 'kite']``.
    """
    # (suffix, numbers of trailing characters to cut), checked in this order.
    suffix_cuts = (
        ('ing', (3,)),
        ('ers', (2, 1, 3)),
        ('es', (2, 1)),
        ('ed', (2, 1)),
        ('er', (2, 1)),
        ('s', (1,)),
    )
    normalized = re.sub(r'[^a-zA-Z\d\s:]', '', word).strip().lower()
    for suffix, cuts in suffix_cuts:
        if len(normalized) > len(suffix) and normalized.endswith(suffix):
            return [normalized] + [normalized[:-cut] for cut in cuts]
    return [normalized]
    
def check_words_against_set(words, s):
    """Return ``True`` if at least one element of ``words`` is in ``s``."""
    return any(candidate in s for candidate in words)

def add_word_to_list(word, base_words, word_list, word_set, word_dict):
    """Assign ``word`` to one split and return its count.

    Appends ``word`` to ``word_list``, adds every entry of ``base_words`` to
    ``word_set`` (both modified in place), and returns ``word_dict[word]``.
    """
    word_list.append(word)
    word_set.update(base_words)
    return word_dict[word]

In [62]:
get_base_words('kites,')

['kites', 'kit', 'kite']

In [63]:
# Greedily split the sorted out-of-vocabulary words into two lists.
# words_N: the words assigned to split N.
# setN:    every base form (from get_base_words) of the words in split N.
# count_N: running sum of vqa_full_test_words[word] (the per-word question
#          count) over the words in split N.
words_1 = []
words_2 = []
set1 = set()
set2 = set()
count_1 = 0
count_2 = 0
for word in full_word_list:
    base_words = get_base_words(word)
    # A word that shares a base form with a split goes to that split, so
    # variants of one word stay together. Split 1 is checked first, so a word
    # matching both splits lands in split 1.
    if check_words_against_set(base_words, set1):
        count_1 += add_word_to_list(word, base_words, words_1, set1, vqa_full_test_words)
    elif check_words_against_set(base_words, set2):
        count_2 += add_word_to_list(word, base_words, words_2, set2, vqa_full_test_words)
    # Otherwise the word goes to the split with the smaller running count
    # (ties go to split 1), which keeps the two totals roughly balanced.
    elif count_1 <= count_2:
        count_1 += add_word_to_list(word, base_words, words_1, set1, vqa_full_test_words)
    else:
        count_2 += add_word_to_list(word, base_words, words_2, set2, vqa_full_test_words)
print(count_1)
print(count_2)

119098
106057


In [64]:
def write_list_to_file(l, filename):
    """Write the strings in ``l`` to ``filename``, one per line.

    The file is opened in text write mode (an existing file is overwritten)
    and each entry is followed by a newline.
    """
    with open(filename, 'w') as sink:
        sink.writelines(entry + '\n' for entry in l)

In [65]:
write_list_to_file(words_1, 'words_test_set_1.txt')
write_list_to_file(words_2, 'words_test_set_2.txt')

In [70]:
num_test1_questions = 0
for db in list_of_data:
    num_test1_questions += count_valid_questions(db, set(words_1))
num_test1_questions

#### COUNTING VALID vqa_devval.db ####
#### COUNTING VALID vqa_train.db ####
#### COUNTING VALID vqa_trainval.db ####
#### COUNTING VALID vqa_vg.db ####


115585

# Testing stuff

In [31]:
s = "'s"
s[0:-2]

''

In [9]:
from pytorch_pretrained_bert import BertTokenizer

In [11]:
toker =  BertTokenizer.from_pretrained(
        'bert-base-cased')

The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.
100%|██████████| 213450/213450 [00:00<00:00, 918661.47B/s]


In [18]:
toker.save_vocabulary('bert_vocab')

'bert_vocab/vocab.txt'

In [18]:
"s"[0:-1] == ''

True

In [40]:
re.split(r'[^a-zA-Z\d\s:]', "ASD)(*)A(S*D)()")

['ASD', '', '', '', 'A', 'S', 'D', '', '', '']

In [39]:
re.sub(r'[^a-zA-Z\d\s:]', '', "\"kid's,")

'kids'

In [7]:
'    a.    b.  '.split(' ')

['', '', '', '', 'a.', '', '', '', 'b.', '', '']

In [17]:
"\n\nasdfasf*&)*    \n".strip()

'asdfasf*&)*'

In [17]:
"ADSSD".lower()

'adssd'

In [22]:
a = []
a += 'a'
a

['a']

In [33]:
a = set()
def addset(s, x):
    """Add ``x`` to ``s`` in place (scratch helper: the cell then displays the caller's set)."""
    s.add(x)
addset(a, "a")
a

{'a'}