In [2]:
import numpy as np
import matplotlib.pyplot as plt
from numpy import linalg as LA

# Directory that holds the CIFAR-10 batch files; load_cifar_data() builds every
# file path by appending a file name to this (relative) path.
cifar_dir = 'cifar-10-batches-py'

def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    The file is opened in binary mode and read with ``encoding='bytes'``, so
    8-bit strings pickled by Python 2 come back as ``bytes`` objects (which is
    why callers index the result with keys such as ``b'data'``).
    """
    import pickle

    # NOTE: unpickling can run arbitrary code -- only use this on trusted files.
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

In [5]:
def _to_channels_last(flat_images, as_float):
    """Turn flat rows into an (N, 32, 32, 3) array, optionally cast to float32."""
    # Each row is reshaped to (3, 32, 32) and the channel axis is then moved
    # to the end, giving (N, 32, 32, 3).
    images = flat_images.reshape((len(flat_images), 3, 32, 32)).transpose(0, 2, 3, 1)
    return images.astype(np.float32) if as_float else images


def load_cifar_data(negatives=False):
    """Load the CIFAR-10 train and test splits from ``cifar_dir``.

    Reads ``batches.meta``, ``data_batch_1`` .. ``data_batch_5`` and
    ``test_batch`` through ``unpickle``.

    Parameters
    ----------
    negatives : bool, optional
        When true, the image arrays are cast to ``np.float32``; otherwise they
        keep the dtype stored in the batch files.
        NOTE(review): the name suggests the float copy is meant for later
        arithmetic that can go negative -- confirm against the callers.

    Returns
    -------
    tuple
        ``(train_data, train_filenames, train_labels,
        test_data, test_filenames, test_labels, cifar_labels)``.
        The two data arrays have shape ``(N, 32, 32, 3)``; filenames, labels
        and ``cifar_labels`` (the ``label_names`` entry of the meta file) are
        NumPy arrays.
    """
    meta = unpickle(cifar_dir + '/batches.meta')
    cifar_labels = np.array(meta[b'label_names'])

    batch_arrays = []
    train_filenames = []
    train_labels = []
    for i in range(1, 6):
        batch = unpickle(cifar_dir + '/data_batch_{}'.format(i))
        batch_arrays.append(batch[b'data'])
        train_filenames.extend(batch[b'filenames'])
        train_labels.extend(batch[b'labels'])

    # Stack once after the loop: stacking inside the loop re-copies everything
    # accumulated so far on every iteration.
    train_data = _to_channels_last(np.vstack(batch_arrays), negatives)
    train_filenames = np.array(train_filenames)
    train_labels = np.array(train_labels)

    test_batch = unpickle(cifar_dir + '/test_batch')
    test_data = _to_channels_last(test_batch[b'data'], negatives)
    test_filenames = np.array(test_batch[b'filenames'])
    test_labels = np.array(test_batch[b'labels'])

    return (train_data, train_filenames, train_labels,
            test_data, test_filenames, test_labels, cifar_labels)

In [6]:
# Load both splits, then flatten every image into a single feature row.
train_data, train_filenames, train_labels, test_data, test_filenames, test_labels, label_names = load_cifar_data()

# Take the row count from the arrays and let NumPy infer the row width, rather
# than hard-coding (50000, 3072) / (10000, 3072): the cell then also works on
# a subset or a differently sized split.
train_data = train_data.reshape(len(train_data), -1)
test_data = test_data.reshape(len(test_data), -1)

# Report the shape of everything that was loaded.
for caption, loaded in (
    ("Train data: ", train_data),
    ("Train filenames: ", train_filenames),
    ("Train labels: ", train_labels),
    ("Test data: ", test_data),
    ("Test filenames: ", test_filenames),
    ("Test labels: ", test_labels),
    ("Label names: ", label_names),
):
    print(caption, loaded.shape)

Train data:  (50000, 3072)
Train filenames:  (50000,)
Train labels:  (50000,)
Test data:  (10000, 3072)
Test filenames:  (10000,)
Test labels:  (10000,)
Label names:  (10,)
