In [3]:
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm 
import tarfile
import numpy as np
import matplotlib.pyplot as plt
import pickle

In [4]:
# Folder this notebook treats as the unpacked CIFAR-10 dataset: its existence
# gates the extraction step and the batch files are read from inside it.
# NOTE(review): confirm this name matches the directory that extracting
# cifar-10-python.tar.gz actually creates; if it differs, the isdir() guard
# never short-circuits and the batch loader looks in the wrong place.
cifar10_dataset_folder_path = "cifar10-batches"

class DownloadProgress(tqdm):
    """tqdm progress bar with a ``hook`` method usable as a ``urlretrieve`` reporthook.

    ``urlretrieve`` reports the cumulative number of blocks transferred, while
    ``tqdm.update`` expects an increment, so the previously seen block count is
    remembered between calls.
    """

    # Block count seen on the previous hook() call (0 before the first call).
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """Advance the bar by the data transferred since the previous call.

        Args:
            block_num: Cumulative number of blocks transferred so far.
            block_size: Size of one block, in bytes.
            total_size: Total size of the download in bytes. ``None`` or a
                non-positive value (``urlretrieve`` passes ``-1`` when the
                server does not report a size) means "unknown" and leaves the
                bar's current total untouched.
        """
        # Only adopt a real, known size; writing -1 into ``total`` would make
        # the bar render against a negative total.
        if total_size is not None and total_size > 0:
            self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

""" 
    Check whether the CIFAR-10 archive (tar.gz) has already been downloaded.
    If not, download it from "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" and save it as cifar-10-python.tar.gz.
    Then extract the archive if the dataset folder does not exist yet.
"""
# Download the archive only when it is not already present in the working directory.
if not isfile('cifar-10-python.tar.gz'):
    with DownloadProgress(unit='B', unit_scale=True, miniters=1, desc='CIFAR-10 Dataset') as pbar:
        urlretrieve(
            'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
            'cifar-10-python.tar.gz',
            pbar.hook)

# Unpack the archive unless the dataset folder already exists.
if not isdir(cifar10_dataset_folder_path):
    # The with-block closes the archive on exit, so no explicit close() is needed.
    # NOTE(review): extractall() trusts the member paths stored in the archive;
    # only run it on an archive obtained from a trusted source.
    with tarfile.open('cifar-10-python.tar.gz') as tar:
        tar.extractall()

In [19]:
def get_cifar10_batch(cifar10_dataset_folder_path, batch_id):
    """Unpickle one batch file and return the resulting object.

    The file read is ``<cifar10_dataset_folder_path>/data_batch_<batch_id>``,
    opened in binary mode.

    Args:
        cifar10_dataset_folder_path: Folder (str) containing the batch files.
        batch_id: Batch identifier; converted with ``str()`` to form the file name.

    Returns:
        Whatever object was pickled in the batch file.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only point this at batch files from a trusted source.
    """
    batch_file_name = 'data_batch_' + str(batch_id)
    batch_path = '/'.join((cifar10_dataset_folder_path, batch_file_name))
    with open(batch_path, mode='rb') as batch_file:
        # The pickle is decoded with encoding='latin1' (presumably because the
        # batches were written by Python 2 -- confirm).
        return pickle.load(batch_file, encoding='latin1')
         
def get_cifar10_features_labels(batch):
    """Split a loaded batch into its features and labels.

    Args:
        batch: Mapping with a ``'data'`` entry and a ``'labels'`` entry.

    Returns:
        A ``(features, labels)`` tuple holding ``batch['data']`` and
        ``batch['labels']`` unchanged.

    Raises:
        KeyError: If either key is missing from ``batch``.
    """
    return batch['data'], batch['labels']

In [20]:
# Load the first batch and split it into features and labels.
batch1 = get_cifar10_batch(cifar10_dataset_folder_path, 1)
features, labels = get_cifar10_features_labels(batch1)

# Sanity check: report the shape of the feature array, heading and value on separate lines.
print('Feature Shape:', features.shape, sep='\n')

Feature Shape:
(10000, 3072)
