In [17]:
import os

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

In [18]:
# Loader adapted from the snippet on the CIFAR-10 download page (https://www.cs.toronto.edu/~kriz/cifar.html)
def unpickle(file):
    """Load one pickled CIFAR-10 batch and return its images and labels.

    Only the two fields this notebook needs are returned, not the whole
    unpickled dict.

    Parameters
    ----------
    file : str or path-like
        Path to a pickled batch file (e.g. ``data_batch_1`` or ``test_batch``).

    Returns
    -------
    data, labels
        The ``'data'`` and ``'labels'`` entries of the unpickled dict.

    Raises
    ------
    KeyError
        If the unpickled dict has no ``'data'`` or ``'labels'`` entry.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code while loading; only call this
    on batch files obtained from a trusted source.
    """
    import pickle
    with open(file, 'rb') as fo:
        # latin1 decodes Python-2-era byte-string keys to str, so the
        # 'data' / 'labels' lookups below work under Python 3.
        batch = pickle.load(fo, encoding='latin1')
    return batch['data'], batch['labels']

In [19]:
# Extract the training data from the five training batches.
# Paths are built with os.path.join so they work on any OS
# (backslash literals are only path separators on Windows).
CIFAR_DIR = 'cifar-10-batches-py'

(b1, l1), (b2, l2), (b3, l3), (b4, l4), (b5, l5) = [
    unpickle(os.path.join(CIFAR_DIR, f'data_batch_{i}')) for i in range(1, 6)
]

# Extract the test data.
test, testLabels = unpickle(os.path.join(CIFAR_DIR, 'test_batch'))
test.shape

(10000, 3072)

In [20]:
# Stack the five training batches (and their label lists) end to end
# into a single image array and a single label array.
x_train = np.concatenate((b1, b2, b3, b4, b5), axis=0)
y_train = np.concatenate((l1, l2, l3, l4, l5), axis=0)
x_train.shape, y_train.shape


((50000, 3072), (50000,))

In [21]:
def to_hwc_images(flat_images):
    """Turn rows of 3 * 32 * 32 values into a (N, 32, 32, 3) image array.

    Each row is reshaped to (3, 32, 32) (three colour planes) and the
    channel axis is then moved last. An array that is already 4-D is
    returned unchanged. This keeps the cell safe to re-run: reshaping an
    already-transposed array would silently scramble the pixels.
    """
    if flat_images.ndim == 4:
        return flat_images
    return flat_images.reshape(len(flat_images), 3, 32, 32).transpose(0, 2, 3, 1)


x_train = to_hwc_images(x_train)
print("x_train shape:", x_train.shape)

x_test = to_hwc_images(test)
print("x_test shape: ", x_test.shape)

y_test = np.array(testLabels)

# Every image must have exactly one label.
assert len(x_train) == len(y_train), "train images / labels length mismatch"
assert len(x_test) == len(y_test), "test images / labels length mismatch"

print("y_train shape: ", y_train.shape)
print("y_test shape: ", y_test.shape)

x_train shape: (50000, 32, 32, 3)
x_test shape:  (10000, 32, 32, 3)
y_train shape:  (50000,)
y_test shape:  (10000,)


In [22]:
# Sanity check: a single sample should be one channels-last image.
print(np.shape(x_train[0]))

(32, 32, 3)


In [23]:
from PIL import Image
class CIFAR10Dataset(Dataset):
    """Map-style dataset over in-memory CIFAR-10 images and labels.

    Parameters
    ----------
    data : numpy.ndarray
        Images indexable as ``data[idx]``, each accepted by
        ``PIL.Image.fromarray``.
    labels : sequence
        One label per image, indexed in parallel with ``data``.
    transform : callable, optional
        Applied to each PIL image. When omitted, images are converted with
        ``transforms.ToTensor()``.
    """

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform
        # Shape of the underlying image array, kept for quick inspection.
        self.shape = data.shape

    def __len__(self):
        """Number of samples (first dimension of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        image = Image.fromarray(self.data[idx])
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)
        else:
            # ToTensor is a transform class: build an instance, then apply it.
            image = transforms.ToTensor()(image)

        return image, label


In [24]:
# Tensor conversion followed by per-channel (x - 0.5) / 0.5,
# which maps values in [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# The same preprocessing is used for both splits.
train_dataset = CIFAR10Dataset(data=x_train, labels=y_train, transform=transform)
test_dataset = CIFAR10Dataset(data=x_test, labels=y_test, transform=transform)

In [25]:
# Summarise the array shape of each split and the number of distinct labels.
print('data shape check')
print(f'train set: {train_dataset.shape}')
print(f'test set: {test_dataset.shape}')
print(f'label numbers: {len(set(train_dataset.labels))}')

data shape check
train set: (50000, 32, 32, 3)
test set: (10000, 32, 32, 3)
label numbers: 10
