In [1]:
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm 
import tarfile
import numpy as np
import matplotlib.pyplot as plt
import pickle
import glob

In [2]:
# Directory the rest of the notebook reads the CIFAR-10 batch files from.
# NOTE(review): the official archive is believed to unpack into
# 'cifar-10-batches-py' -- confirm this name matches the folder on disk.
path = 'cifar10-batches'

class DownloadProgress(tqdm):
    """tqdm progress bar that can be driven by ``urlretrieve``'s report hook.

    The report hook receives the *cumulative* number of blocks transferred,
    whereas ``tqdm.update`` expects an increment, so the bar remembers the
    last block number it saw and advances by the difference.
    """

    # Block number reported by the most recent ``hook`` call.
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """Advance the bar; usable as the ``reporthook`` of ``urlretrieve``.

        Args:
            block_num: Number of blocks transferred so far.
            block_size: Size of one block, in bytes.
            total_size: Total download size in bytes. ``None`` or a
                non-positive value is treated as "unknown" and leaves the
                bar's current total untouched.
        """
        # Only replace the total with a usable size, so a call that omits it
        # (or a server that reports -1) does not erase a known total.
        if total_size is not None and total_size > 0:
            self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

""" 
    Check whether the dataset archive (a gzipped tarball) has already been downloaded.
    If not, download it from "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" and save it as cifar-10-python.tar.gz.
"""
# Fetch the archive only when it is not already in the working directory.
if not isfile('cifar-10-python.tar.gz'):
    with DownloadProgress(unit='B', unit_scale=True, miniters=1, desc='CIFAR-10 Dataset') as pbar:
        urlretrieve(
            'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
            'cifar-10-python.tar.gz',
            pbar.hook)

# Unpack the archive only when the data directory is missing.
if not isdir(path):
    # The `with` statement closes the archive; no explicit close() is needed.
    with tarfile.open('cifar-10-python.tar.gz') as tar:
        if hasattr(tarfile, 'data_filter'):
            # The archive comes from the network: use the safe 'data'
            # extraction filter on Python versions that provide it.
            tar.extractall(filter='data')
        else:
            tar.extractall()

In [3]:
# Count the files in `path` whose names start with "data".
# glob.glob is the public API (glob.glob1 is an undocumented helper that is
# deprecated in recent Python versions); glob.escape keeps the directory name
# literal, as glob1 did.
batch_files = len(glob.glob(glob.escape(path) + '/data*'))
print(batch_files)

5


In [4]:
def get_cifar10_batch(path, batch_id):
    """Unpickle and return the contents of one ``data_batch_<batch_id>`` file.

    Args:
        path: Directory containing the batch files (joined with '/').
        batch_id: Suffix of the batch file name; converted with ``str``.

    Returns:
        Whatever object the file's pickle stream holds.
    """
    file_name = 'data_batch_' + str(batch_id)
    full_name = path + '/' + file_name
    with open(full_name, mode='rb') as handle:
        # encoding='latin1' controls how 8-bit strings in the pickle are
        # decoded. pickle.load can run arbitrary code, so only use this on
        # files from a trusted source.
        contents = pickle.load(handle, encoding='latin1')
    return contents
         
def get_cifar10_features_labels(batch):
    """Split a batch mapping into its features and labels.

    Args:
        batch: Mapping with a 'data' entry and a 'labels' entry.

    Returns:
        A ``(features, labels)`` tuple: ``batch['data']`` as stored, and
        ``batch['labels']`` converted to a NumPy array.
    """
    return batch['data'], np.array(batch['labels'])

In [5]:
# Recount the files in `path` whose names start with "data".
# glob.glob is the public API (glob.glob1 is an undocumented helper that is
# deprecated in recent Python versions); glob.escape keeps the directory name
# literal, as glob1 did.
batch_files = len(glob.glob(glob.escape(path) + '/data*'))
print(batch_files)

5


In [7]:
# Kept as a notebook-level name in case other cells reference it.
batch1 = get_cifar10_batch(path, 1)
feature_list = []
label_list = []

# Load each training batch in turn. The loop index is passed to the loader so
# every iteration reads a different batch file instead of re-reading batch 1.
for number in range(1, batch_files + 1):
    batch = get_cifar10_batch(path, number)
    features, labels = get_cifar10_features_labels(batch)
    feature_list.append(features)
    label_list.append(labels)

# Stack the per-batch results along a new leading axis (one entry per batch).
feature_array = np.array(feature_list)
label_array = np.array(label_list)
print('Feature Shape:')
print(feature_array.shape)
print('Labels Shape:')
print(label_array.shape)

Feature Shape:
(5, 10000, 3072)
Labels Shape:
(5, 10000)


In [8]:
# Merge the leading per-batch axis into the sample axis.
# reshape gives the same row order as vstack/concatenate on the stacked
# arrays, and it is idempotent: re-running this cell on the already-flattened
# arrays is a no-op, whereas np.concatenate on a 1-D array raises ValueError.
feature_array = feature_array.reshape(-1, feature_array.shape[-1])
label_array = label_array.reshape(-1)
print('Feature Shape:')
print(feature_array.shape)
print('Labels Shape:')
print(label_array.shape)

Feature Shape:
(50000, 3072)
Labels Shape:
(50000,)
