In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import io
import json
import math
import os
import pickle
import sys

import boto3
import botocore
import numpy as np
import pywren

repo_root = os.path.join(os.getcwd(), '../code')
sys.path.append(repo_root)

import cifar10
import utils

In [2]:
def key_exists(bucket, key):
    '''Return True if ``key`` exists in S3 ``bucket``, False if it is missing.

    Issues ``head_object`` and inspects only whether it succeeds.

    Raises:
        botocore.exceptions.ClientError: for any error code other than '404'
            (e.g. access denied), so such failures are not mistaken for a
            missing key.
    '''
    client = boto3.client('s3')
    try:
        # Only success vs. failure matters; the response itself is unused.
        client.head_object(Bucket=bucket, Key=key)
    except botocore.exceptions.ClientError as exc:
        if exc.response['Error']['Code'] != '404':
            raise
        return False
    return True

def make_key(batch_index):
    '''Build the S3 key under which the result of batch ``batch_index`` is cached.'''
    template = 'tinyimages_nearest_neighbor_tmp/ti_nn_batch_{}'
    return template.format(batch_index)

def get_data_for_key(bucket, key):
    '''Fetch the S3 object ``bucket``/``key`` and return it unpickled.

    NOTE(security): pickle.loads can execute arbitrary code embedded in the
    payload; only use this on buckets whose contents you trust.
    '''
    response = boto3.client('s3').get_object(Bucket=bucket, Key=key)
    payload = response['Body'].read()
    return pickle.loads(payload)

def store_data_for_key(bucket, key, obj):
    '''Pickle ``obj`` and upload the bytes to S3 as ``bucket``/``key``.'''
    s3_client = boto3.client('s3')
    s3_client.put_object(Bucket=bucket, Key=key, Body=pickle.dumps(obj))
    
def make_distance_matrix(X_test, X_train):
    '''Return the matrix of *squared* Euclidean distances between rows.

    ``D[i, j] = ||X_test[i] - X_train[j]||^2``, computed via the expansion
    ``||a||^2 - 2 a.b + ||b||^2`` so that the bulk of the work is a single
    matrix product.

    Args:
        X_test: 2-D array, one sample per row.
        X_train: 2-D array with the same number of columns as ``X_test``.

    Returns:
        Array of shape ``(len(X_test), len(X_train))`` whose entries are
        clipped to be >= 0.
    '''
    D = X_test.dot(X_train.T)
    D *= -2
    # Sum of squares directly, not norm(...)**2: the sqrt/square round trip
    # is inexact, whereas this is exact for integer-valued float64 inputs
    # (e.g. raw pixels), so identical rows cancel to exactly 0.
    D += np.einsum('ij,ij->i', X_train, X_train)[np.newaxis, :]
    D += np.einsum('ij,ij->i', X_test, X_test)[:, np.newaxis]
    # Cancellation can still leave tiny negatives for general floats; clip so
    # a caller's np.sqrt never produces NaN (NaN sorts last in argsort, which
    # would hide exact duplicates from a nearest-neighbor search).
    np.maximum(D, 0, out=D)
    return D

def compute_nearest_neighbors_batch(b_tuple, k=10, bucket='cifar-10-1'):
    '''Find the k nearest CIFAR-10 images (L2 pixel distance) for a batch of TinyImages.

    Intended to be mapped over batches with pywren. The result for each batch
    is cached in S3 under ``make_key(batch_index)``: if that key exists the
    stored result is returned as-is (note the key does not include ``k``);
    otherwise the result is computed, stored, and returned.

    Args:
        b_tuple: ``(batch_index, b)`` where ``b`` is a sequence of TinyImages
            indices used as keys into the pickled TinyImages subset.
        k: Number of nearest neighbors to keep per image.
        bucket: S3 bucket that holds ``cifar10/*`` and
            ``tinyimage_subset_data.pickle`` and receives the cached results.
            Replace the default with a bucket you have write access to.

    Returns:
        One ``(index, neighbors)`` pair per entry of ``b``, where ``neighbors``
        is a list of ``(cifar_index, distance)`` tuples in increasing distance.
        An empty batch yields ``[]``.
    '''
    batch_index, b = b_tuple
    result_key = make_key(batch_index)
    s3 = boto3.client('s3')

    if key_exists(bucket, result_key):
        return get_data_for_key(bucket, result_key)

    if len(b) == 0:
        res = []
        store_data_for_key(bucket, result_key, res)
        return res

    # Download only the CIFAR-10 files that are not already present locally.
    cifar_dir = 'data/cifar10'
    cifar_files = ['data_batch_1', 'data_batch_2', 'data_batch_3',
                   'data_batch_4', 'data_batch_5', 'batches.meta', 'test_batch']
    missing_files = [name for name in cifar_files
                     if not os.path.exists(os.path.join(cifar_dir, name))]
    if missing_files:
        print('Downloading CIFAR10 data ...')
        # exist_ok tolerates a pre-existing 'data' or partially populated
        # 'data/cifar10' directory (os.mkdir would raise FileExistsError).
        os.makedirs(cifar_dir, exist_ok=True)
        for name in missing_files:
            s3.download_file(bucket, 'cifar10/' + name, os.path.join(cifar_dir, name))
    cifar = cifar10.CIFAR10Data(cifar_dir)
    cifar_images = np.reshape(cifar.all_images, [60000, -1])
    dim = 32 * 32 * 3
    assert cifar_images.shape[1] == dim
    cifar_images = cifar_images.astype(np.float64)

    # NOTE(security): pickle.loads executes arbitrary code from the payload;
    # only point `bucket` at data you trust.
    pickle_bytes = s3.get_object(Bucket=bucket, Key='tinyimage_subset_data.pickle')['Body'].read()
    tinyimages = pickle.loads(pickle_bytes)

    bsize = len(b)
    batch_images_list = []
    for index in b:
        flat = np.reshape(tinyimages[index], [-1])
        assert flat.shape == (dim,)
        batch_images_list.append(flat)
    batch_images = np.vstack(batch_images_list).astype(np.float64)
    assert batch_images.shape == (bsize, dim)

    # Squared distances -> distances, in place to avoid allocating a second
    # (bsize x 60000) float64 array. Clipping at 0 guards against tiny
    # negative values from floating-point cancellation, which np.sqrt would
    # turn into NaN (and argsort would then rank last).
    dst_matrix = make_distance_matrix(batch_images, cifar_images)
    np.maximum(dst_matrix, 0, out=dst_matrix)
    np.sqrt(dst_matrix, out=dst_matrix)
    assert dst_matrix.shape == (bsize, 60000)

    res = []
    for row, index in enumerate(b):
        cur_dsts = dst_matrix[row]
        top_indices = np.argsort(cur_dsts)[:k]
        res.append((index, [(int(j), float(cur_dsts[j])) for j in top_indices]))

    store_data_for_key(bucket, result_key, res)
    return res

def split_into_batches(inputs, num_batches):
    '''Cut ``inputs`` into ``num_batches`` contiguous slices.

    Each slice holds ``ceil(len(inputs) / num_batches)`` items, except the
    trailing ones, which may be shorter or even empty. The chosen batch size
    is printed.
    '''
    total = len(inputs)
    batch_size = int(math.ceil(total / num_batches))
    print('Batch size: {}'.format(batch_size))
    starts = (i * batch_size for i in range(num_batches))
    return [inputs[start:min(start + batch_size, total)] for start in starts]

In [3]:
# Load the TinyImages subset grouped by keyword. NOTE: the S3 pickle read in
# compute_nearest_neighbors_batch ('tinyimage_subset_data.pickle') is a fixed
# key and does not follow `version_string`; if you change the version here,
# make sure the S3 data matches it by hand.
version_string = ''
ti_by_kw, _ = utils.load_tinyimage_subset(version_string=version_string)

# Flatten every keyword's entries into one sorted list of TinyImages indices.
img_indices = sorted(item['tinyimage_index']
                     for kw_items in ti_by_kw.values()
                     for item in kw_items)

Loading indices from file /Users/ludwig/research/deep_learning/tinyimages/repo/data/tinyimage_subset_indices.json
Loading image data from file /Users/ludwig/research/deep_learning/tinyimages/repo/data/tinyimage_subset_data.pickle


In [4]:
# Total number of TinyImages indices gathered from ti_by_kw (one per entry).
len(img_indices)

589711

In [6]:
# Number of images to process. For a quick smoke test use e.g.
# num_to_try = 1000 with num_batches = 100.
num_to_try = len(img_indices)
# Number of pywren tasks; each batch is cached separately in S3.
num_batches = 1000

keys = img_indices[:num_to_try]

# Pair every batch with its index, which becomes its S3 cache key.
input_data_batches = list(enumerate(split_into_batches(keys, num_batches)))

Batch size: 590


In [7]:
# Launch one pywren task per (batch_index, batch) pair. job_max_runtime is set
# very high, presumably so long batches are not killed by the executor's
# runtime limit (units assumed to be seconds — TODO confirm with pywren docs).
pwex = pywren.standalone_executor(job_max_runtime=999999)
futures = pwex.map(compute_nearest_neighbors_batch, input_data_batches)

In [8]:
# Wait for every task, reporting each failure with its batch index instead of
# stopping at the first one. Finished batches are cached in S3, so re-running
# the map only recomputes the failed ones.
failed_batches = []
for batch_index, future in enumerate(futures):  # batch i <-> input_data_batches[i]
    try:
        future.result()
    except Exception as exc:
        failed_batches.append(batch_index)
        print('Batch {} failed: {}'.format(batch_index, exc))
print('{} of {} batches failed'.format(len(failed_batches), len(futures)))

In [9]:
# Collect the return value of every task: one list of (index, neighbors) pairs per batch.
results = pywren.get_all_results(futures)

In [10]:
# Should equal num_batches: one result per pywren task.
print('Collected {} batches of results'.format(len(results)))

Collected 1000 batches of results


In [11]:
# Merge the per-batch lists into one mapping:
# TinyImages index -> [(cifar_index, distance), ...] (later duplicates win).
res_dict = {entry[0]: entry[1] for batch_result in results for entry in batch_result}

In [12]:
# Number of distinct TinyImages indices that received a neighbor list.
print('Collected {} total top-k nearest neighbors'.format(len(res_dict)))

Collected 589711 total top-k nearest neighbors


In [13]:
# Persist the neighbors as JSON. JSON object keys are always strings, so the
# integer TinyImages indices are written as strings, and the
# (cifar_index, distance) tuples become 2-element lists.
filename = 'tinyimage_cifar10_distances_full.json'
with open(filename, 'w') as out_file:
    json.dump(obj=res_dict, fp=out_file, indent=2)
print('Saved to {}'.format(filename))

Saved to tinyimage_cifar10_distances_full.json


In [14]:
# Spot-check the neighbor list computed for one TinyImages index.
res_dict[69341]

[(11991, 2221.019360563973),
 (28387, 2264.6964476503263),
 (47852, 2273.062691612351),
 (39223, 2278.320653463862),
 (33869, 2296.747918253111),
 (9199, 2304.886765114505),
 (6518, 2305.2984622386753),
 (37146, 2307.6037354797277),
 (33061, 2312.686533017394),
 (29811, 2328.9849720425436)]

In [None]:
# Regression check: expected top-10 neighbors of TinyImages index 69341 as
# (cifar_index, distance) pairs, recorded from an earlier full run.
expected_69341 = [(11991, 2221.019360563973),
                  (28387, 2264.6964476503263),
                  (47852, 2273.062691612351),
                  (39223, 2278.320653463862),
                  (33869, 2296.747918253111),
                  (9199, 2304.886765114505),
                  (6518, 2305.2984622386753),
                  (37146, 2307.6037354797277),
                  (33061, 2312.686533017394),
                  (29811, 2328.9849720425436)]
actual_69341 = res_dict[69341]
# Indices must match exactly; distances only up to float round-off, since the
# summation order of the underlying matrix product may differ between runs.
assert [idx for idx, _ in actual_69341] == [idx for idx, _ in expected_69341], 'nearest-neighbor indices differ'
assert np.allclose([dst for _, dst in actual_69341], [dst for _, dst in expected_69341]), 'nearest-neighbor distances differ'