## Keyword Counts for Class Balanced Datasets

This notebook determines how many images of each keyword to include per class when building a class-balanced dataset.

In [3]:
import numpy as np
import json
import math
from collections import Counter
import sys
import os

repo_root = os.path.join(os.getcwd(), '../code')
sys.path.append(repo_root)
import utils
import cifar10

# Version tag substituted into the versioned data file names used by this notebook.
version = '7'

# Keyword metadata; entry ii is paired with cifar.all_labels[ii] further down.
keywords_path = '../other_data/cifar10_keywords_unique_v{}.json'.format(version)
with open(keywords_path) as keywords_file:
    cifar10_keywords = json.load(keywords_file)

# Dataset wrapper (provides all_labels) and the class-index -> class-name table.
cifar = cifar10.CIFAR10Data('../other_data/cifar10')
class_names = utils.cifar10_label_names

print('Length of cifar10 keywords {}'.format(len(cifar10_keywords)))

def compute_top_k_keywords_per_class(cifar10_keywords, k, labels=None, label_names=None):
    '''
    Count keyword occurrences per class and keep the k most frequent per class.

    cifar10_keywords: sequence of entries, each providing an 'nn_keyword'
        field. Entry ii is paired with labels[ii].
    k: number of keywords to keep per class. Keywords with equal counts keep
        the order in which they were first seen.
    labels: optional sequence of class indices aligned with cifar10_keywords.
        Defaults to the notebook-level cifar.all_labels.
    label_names: optional sequence mapping a class index to its class name.
        Defaults to the notebook-level class_names.

    Returns a pair (top_k_keywords_per_class, total_keyword_counts_per_class):
        - class name -> {keyword: count} for the k most frequent keywords,
          in descending count order;
        - class name -> number of entries of that class over all keywords
          (not only the top k).
    Only classes that occur at least once appear in either dictionary.
    '''
    # Fall back to the notebook globals so existing two-argument calls keep working.
    if labels is None:
        labels = cifar.all_labels
    if label_names is None:
        label_names = class_names

    keywords_per_class = {}
    for ii, entry in enumerate(cifar10_keywords):
        cur_keyword = entry['nn_keyword']
        cur_label = label_names[labels[ii]]
        keywords_per_class.setdefault(cur_label, Counter())[cur_keyword] += 1

    top_k_keywords_per_class = {}
    total_keyword_counts_per_class = {}
    for label, keyword_counter in keywords_per_class.items():
        # sorted() is stable (also with reverse=True), so ties stay in
        # first-appearance order.
        sorted_keywords = sorted(keyword_counter.items(), key=lambda x: x[1], reverse=True)
        total_keyword_counts_per_class[label] = sum(keyword_counter.values())
        # Plain dicts are returned so the result serializes and compares as before.
        top_k_keywords_per_class[label] = dict(sorted_keywords[:k])

    return top_k_keywords_per_class, total_keyword_counts_per_class
   
       
    
def compute_new_keyword_counts(new_dataset_size, top_k_keywords_per_class,
                               total_keyword_counts_per_class, use_total_keyword_counts=False,
                               label_names=None):
    '''
    Split a per-class image budget across keywords in proportion to their counts.

    Each class receives new_dataset_size / (number of classes) images. Within a
    class, every keyword first gets the floor of its proportional share; the
    images still missing are then handed out one each to the keywords with the
    largest fractional remainder.

    new_dataset_size: total number of images; must be divisible by the number
        of classes.
    top_k_keywords_per_class: dictionary from class to another dictionary. The
        inner dictionary goes from keyword to keyword count in CIFAR10.
    total_keyword_counts_per_class: dictionary from class to the total keyword
        count of that class; only read when use_total_keyword_counts is True.
    use_total_keyword_counts: if True, shares are computed relative to the
        class total instead of the sum over the supplied keywords.
    label_names: optional sequence of class names to process, in output order.
        Defaults to the notebook-level class_names.

    Returns a dictionary from class to {keyword: new count}, with keywords
    whose new count is zero left out.
    '''
    # Fall back to the notebook global so existing calls keep working.
    if label_names is None:
        label_names = class_names
    num_classes = len(label_names)
    assert num_classes > 0, 'need at least one class'
    assert new_dataset_size % num_classes == 0, \
        'new_dataset_size must be divisible by the number of classes'
    num_per_class = int(new_dataset_size // num_classes)

    result = {}
    for label in label_names:
        keyword_counts = top_k_keywords_per_class[label]
        if use_total_keyword_counts:
            cur_keyword_count = total_keyword_counts_per_class[label]
        else:
            cur_keyword_count = sum(keyword_counts.values())

        new_keyword_rounding_gap = {}
        new_keyword_count = {}
        for keyword, value in keyword_counts.items():
            frequency = num_per_class * (value / cur_keyword_count)
            floored = int(math.floor(frequency))
            new_keyword_rounding_gap[keyword] = frequency - floored
            new_keyword_count[keyword] = floored
        total_count = sum(new_keyword_count.values())
        print(total_count, num_per_class)

        # Flooring can only lose images, and fewer than one per keyword; the
        # second check also guarantees the top-up loop below stays in range.
        assert total_count <= num_per_class, \
            'floored counts exceed the per-class budget'
        assert total_count >= num_per_class - len(keyword_counts), \
            'too few keywords to fill the per-class budget'

        # sort the keywords by the rounding gap
        new_keyword_rounding_gap_sorted = sorted(new_keyword_rounding_gap.items(),
                                                 key=lambda x: x[1], reverse=True)
        for ii in range(num_per_class - total_count):
            keyword = new_keyword_rounding_gap_sorted[ii][0]
            new_keyword_count[keyword] += 1

        # Keywords that ended up with no images are dropped from the output.
        result[label] = {key: value for key, value in new_keyword_count.items() if value > 0}
    return result
    
# Number of most frequent keywords kept per class.
top_k = 50

top_k_keywords_per_class, total_keyword_counts_per_class = compute_top_k_keywords_per_class(
    cifar10_keywords, top_k)

# Share of all keyword occurrences that the top-k keywords account for.
top_k_sum = sum(sum(keyword_counts.values())
                for keyword_counts in top_k_keywords_per_class.values())
total_sum = sum(total_keyword_counts_per_class.values())

print(top_k_sum / total_sum)
# Same ratio against a fixed 60000; presumably the expected dataset size, kept as a cross-check.
print(top_k_sum / 60000)


result = compute_new_keyword_counts(2000, top_k_keywords_per_class, total_keyword_counts_per_class)

counts_path = '../other_data/keyword_counts_v{}.json'.format(version)
with open(counts_path, 'w') as counts_file:
    json.dump(result, counts_file, indent=2)

# Per-class totals of the newly assigned counts.
for label, keyword_counts in result.items():
    total_count = sum(keyword_counts.values())
    print('{}, {}'.format( label, total_count))

# Report every keyword that was assigned to more than one class. Each clash is
# printed once per direction (label -> label2 and label2 -> label).
for label, keyword_counts in result.items():
    for keyword in keyword_counts:
        for label2, other_counts in result.items():
            if label != label2 and keyword in other_counts:
                print('ERROR {} appears for {} and {}'.format(keyword, label, label2))

Length of cifar10 keywords 60000
0.9959333333333333
0.9959333333333333
187 200
190 200
181 200
195 200
185 200
188 200
174 200
182 200
178 200
185 200
airplane, 200
automobile, 200
bird, 200
cat, 200
deer, 200
dog, 200
frog, 200
horse, 200
ship, 200
truck, 200
ERROR cruiser appears for automobile and ship
ERROR cruiser appears for ship and automobile


## Check that we have enough labeled images

In [6]:
# Number of indices already recorded per keyword in the v4 file.
with open('../other_data/cifar10.1_v4_ti_indices_per_keyword.json', 'r') as v4_file:
    v4_indices = json.load(v4_file)
v4_count = {key: len(value) for key, value in v4_indices.items()}

good_indices_path = '../other_data/tinyimage_good_indices_subselected_v{}.json'.format(version)
with open(good_indices_path, 'r') as good_indices_file:
    tinyimage_good_indices = json.load(good_indices_file)

blacklist_path = '../other_data/blacklist_v{}.json'.format(version)
with open(blacklist_path, 'r') as blacklist_file:
    blacklist = json.load(blacklist_file)

# Remove blacklisted items from every keyword's index list, in place.
# remove() drops one occurrence per blacklist item.
for item in blacklist:
    for indices in tinyimage_good_indices.values():
        if item in indices:
            indices.remove(item)


num_total_new_keywords = 0
num_total_warnings = 0
num_total_new = 0
new_keywords = []
for label, keyword_counts in result.items():
    for keyword, count_new in keyword_counts.items():
        count_old = v4_count.get(keyword, 0)

        if keyword in tinyimage_good_indices:
            num_available = len(tinyimage_good_indices[keyword])
            if count_new > num_available:
                print('keyword {} does not have enough tinyimage good indices'.format(keyword))
                num_total_new_keywords += 1
                new_keywords.append(keyword)
            # The old count must fit within the indices still available.
            assert count_old <= num_available
        else:
            print('keyword {} not in tinyimage good indices'.format(keyword))
            num_total_new_keywords += 1
            new_keywords.append(keyword)
            # A keyword without any good indices cannot have an old count.
            assert count_old == 0

        # Report every keyword whose required count differs from the old one.
        if count_old > count_new:
            num_total_warnings += 1
            print('{} {} {} WARNING, TOO MANY OLD'.format(keyword, count_old, count_new))
        elif count_old < count_new:
            num_total_new += 1
            print('{} {} {}'.format(keyword, count_old, count_new))
print(num_total_warnings, num_total_new, num_total_new_keywords)
print(new_keywords)

stealth_bomber 19 20
multiengine_airplane 3 4
aeroplane 0 1
hangar_queen 0 1
bird 17 18
passerine 10 11
dromaius_novaehollandiae 3 4
bird_of_passage 0 2
songbird 0 2
rhea 0 2
cock 0 2
hen 0 1
skylark 0 1
kiwi 0 1
night_bird 0 1
ratite 0 1
elephant_bird 0 1
gamecock 0 1
tabby_cat 36 37
domestic_cat 27 28
keyword felis_domesticus does not have enough tinyimage good indices
felis_domesticus 11 12
fallow_deer 13 15
rangifer_tarandus 5 6
cervus_unicolor 0 2
japanese_deer 0 2
musk_deer 0 2
barren_ground_caribou 0 1
reindeer 0 1
european_elk 0 1
brocket 0 1
pekingese 25 26
bufo_bufo 13 20
true_frog 0 3
texas_toad 0 3
anuran 0 2
barking_frog 0 2
tailed_frog 0 2
bufo_boreas 0 2
bufo_debilis 0 2
southwestern_toad 0 2
european_toad 0 1
midwife_toad 0 1
cascades_frog 0 1
leptodactylus_pentadactylus 0 1
yosemite_toad 0 1
bufo_speciosus 0 1
bufo_microscaphus 0 1
leptodactylid 0 1
rana_cascadae 0 1
true_toad 0 1
quarter_horse 15 16
lippizaner 11 12
american_saddle_horse 0 1
plantation_walking_horse 0