# Sample new labeled images for the class-balanced test sets

Samples the subselected and labeled TinyImages indices to create a new class-balanced test set.

The script requires the following files from s3 in the `other_data` directory:
* `tinyimage_large_dst_images_v6.1.json`
* `tinyimage_large_dst_image_data_v6.1.pickle`

These files can be downloaded with `other_data/download.py --all`.

Smaller files required that are checked in to the repo:
* `tinyimage_good_indices_subselected_v{}.json`
* `blacklist_v{}.json`
* `keyword_counts_v{}.json`

In addition, the CIFAR-10 dataset should be downloaded to `other_data/cifar10`.

In [6]:
%load_ext autoreload
%autoreload 2

import io
import json
import math
import pickle
import random
import os
import sys
from copy import deepcopy

from IPython.display import display
from ipywidgets import widgets
import numpy as np
import PIL.Image
import tqdm

repo_root = os.path.join(os.getcwd(), '../code')
sys.path.append(repo_root)

import cifar10
import utils

def _load_json(path):
    """Return the decoded contents of the JSON file at `path`."""
    with open(path, 'r') as json_file:
        return json.load(json_file)


# Version suffix used for the versioned input files below.
version = '7'

cifar = cifar10.CIFAR10Data('../other_data/cifar10')
cifar_labels = cifar.all_labels

# Both v6 and v7 currently use the v6.1 image files
all_new_imgs = _load_json('../other_data/tinyimage_large_dst_images_v6.1.json')
# NOTE(review): unpickling runs code embedded in the file, so this should only
# ever be pointed at the trusted project download.
with open('../other_data/tinyimage_large_dst_image_data_v6.1.pickle', 'rb') as pickle_file:
    img_data = pickle.load(pickle_file)

tinyimage_good_indices = _load_json(
    '../other_data/tinyimage_good_indices_subselected_v{}.json'.format(version))
keyword_counts_per_class = _load_json(
    '../other_data/keyword_counts_v{}.json'.format(version))
# Blacklist contains images that are near-duplicates in CIFAR-10
blacklist = _load_json('../other_data/blacklist_v{}.json'.format(version))

# Drop every blacklisted index from each keyword's list of good indices
# (the blacklist mostly contains near duplicates with CIFAR-10).
# Filtering against a set removes *all* occurrences of a blacklisted index;
# list.remove() would only delete the first one, and the repeated
# `in` / `remove` scans were quadratic in the list length.
blacklisted = set(blacklist)
for keyword in tinyimage_good_indices:
    # Slice assignment updates the existing list object in place, just as the
    # previous remove()-based code did.
    tinyimage_good_indices[keyword][:] = [
        idx for idx in tinyimage_good_indices[keyword] if idx not in blacklisted
    ]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Sample images


In [8]:
# Fixed seed so that the sampled dataset is reproducible.
random.seed(670725112)
dataset_size = 2000
new_data = np.empty((dataset_size, 32, 32, 3), float)
new_labels = np.empty(dataset_size, int)

# Disambiguates whether the cruiser indices belong to the automobile or ship class
with open('../other_data/cruiser_good_indices.json', 'r') as f:
    cruiser_good_indices = json.load(f)

# Both v6 and v7 use the v4 indices as a starting point
with open('../other_data/cifar10.1_v4_ti_indices_per_keyword.json', 'r') as f:
    v4_indices = json.load(f)

ii = 0                # next free row in new_data / new_labels
tiny_image_map = []   # tiny_image_map[row] is the image index stored in that row
ti_indices = {}       # label -> keyword -> list of selected image indices
for label in keyword_counts_per_class:
    ti_indices[label] = {}
    # The numeric class id is the same for every keyword of this label.
    label_id = cifar.label_names.index(label)
    for keyword in keyword_counts_per_class[label]:
        # Start from the v4 selection for this keyword, if there is one.
        # 'cruiser' under the "ship" label always starts empty. The list is
        # copied so the edits below do not modify `v4_indices` itself.
        if keyword == 'cruiser':
            existing_indices = [] if label == "ship" else list(v4_indices[keyword])
        else:
            existing_indices = list(v4_indices.get(keyword, []))
        ti_indices[label][keyword] = existing_indices

        count_new = keyword_counts_per_class[label][keyword]
        count_old = len(existing_indices)
        if count_old > count_new:
            # Remove images if we have too many for this particular keyword.
            num_images_to_remove = count_old - count_new
            # The sample is drawn in full before anything is removed.
            sampled_indices_to_remove = random.sample(existing_indices, num_images_to_remove)
            for item in sampled_indices_to_remove:
                ti_indices[label][keyword].remove(item)
        elif count_new > count_old:
            # Add images if we need more for this particular keyword.
            num_new_images = count_new - count_old

            # Determine the set of new labeled indices to sample from
            if keyword == "cruiser":
                cur_good_indices = cruiser_good_indices[label]
            else:
                cur_good_indices = tinyimage_good_indices[keyword]
            # random.sample picks by position, so the candidate ordering is
            # kept exactly as-is: the untouched list when nothing is selected
            # yet, the set difference otherwise.
            if count_old == 0:
                ti_good_indices = cur_good_indices
            else:
                ti_good_indices = list(set(cur_good_indices) - set(existing_indices))

            # Sample new indices to add
            ti_sampled_indices = random.sample(ti_good_indices, num_new_images)
            ti_indices[label][keyword].extend(ti_sampled_indices)

        # Get the images and labels corresponding to the sampled indices
        for idx in ti_indices[label][keyword]:
            tiny_image_map.append(idx)
            new_data[ii] = img_data[idx]
            new_labels[ii] = label_id
            ii += 1

# np.empty leaves unfilled rows uninitialized, so refuse to write the files
# unless every row was filled.
if ii != dataset_size:
    raise ValueError(
        'Sampled {} images but dataset_size is {}'.format(ii, dataset_size))

# Convert once and reuse for both output locations.
new_data_uint8 = new_data.astype(np.uint8)
new_labels_int32 = new_labels.astype(np.int32)

with open('../other_data/cifar10.1_v{}_ti_indices_map.json'.format(version), 'w') as f:
    json.dump(tiny_image_map, f, indent=2)
np.save('../datasets/cifar10.1_v{}_data.npy'.format(version), new_data_uint8)
np.save('../datasets/cifar10.1_v{}_labels.npy'.format(version), new_labels_int32)
# These paths previously hard-coded "v7" and called .format() on a string
# without a placeholder; they now follow `version` like the paths above.
with open('../../tiny_images_2/data/tiny_image_map_v{}.json'.format(version), 'w') as f:
    json.dump(tiny_image_map, f, indent=2)
np.save('../../tiny_images_2/data/new_cifar10_data_v{}.npy'.format(version), new_data_uint8)
np.save('../../tiny_images_2/data/new_cifar10_labels_v{}.npy'.format(version), new_labels_int32)


# Check that the dataset is class balanced: every class id must occur exactly
# dataset_size / num_classes times in new_labels.
num_classes = len(cifar.label_names)
expected_count = dataset_size / num_classes
label_counts = {}  # class id -> number of samples carrying that label
for i in range(num_classes):
    total_count = int(np.sum(new_labels == i))
    label_counts[i] = total_count
    assert total_count == expected_count, (
        'class {} has {} samples, expected {}'.format(i, total_count, expected_count))
            