In [1]:
import torchvision
import torchvision.transforms as transforms

# Define a transform to convert images to tensors
transform = transforms.ToTensor()

# Download the training dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Download the test dataset
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

print("Number of training images:", len(trainset))
print("Number of test images:", len(testset))


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [02:11<00:00, 1298468.54it/s]


Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified
Number of training images: 50000
Number of test images: 10000


In [2]:
import pickle
from PIL import Image
import os

def unpickle(file):
    with open(file, 'rb') as fo:
        dict_data = pickle.load(fo, encoding='bytes')
    return dict_data

def save_images_from_batch(batch_file, output_folder):
    # Create output directory if it doesn't exist
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Load the data batch
    batch = unpickle(batch_file)
    images = batch[b'data']      # shape (10000, 3072)
    labels = batch[b'labels']    # labels if needed
    # Optional: filenames might be provided
    filenames = batch.get(b'filenames')

    # Iterate through each image in the batch
    for i in range(images.shape[0]):
        # Reshape the flat array into 3x32x32 then transpose to 32x32x3
        img_array = images[i].reshape(3, 32, 32).transpose(1, 2, 0)
        img = Image.fromarray(img_array)

        # Use the filename from the batch if available, otherwise generate one
        if filenames is not None:
            filename = filenames[i].decode('utf-8')  # Convert bytes to string
        else:
            filename = f"img_{i}.png"

        # Save the image
        img.save(os.path.join(output_folder, filename))

In [3]:
import os

data_folder = 'data/cifar-10-batches-py'
train_folder = 'data/cifar-10/train_images'
test_folder = 'data/cifar-10/test_images'

for filename in os.listdir(data_folder):
    if filename.startswith('data_batch_'):
        filepath = os.path.join(data_folder, filename)
        if os.path.isfile(filepath):
                save_images_from_batch(filepath, train_folder)
                print(filepath)
    elif filename.startswith('test_batch'):
        filepath = os.path.join(data_folder, filename)
        if os.path.isfile(filepath):
                save_images_from_batch(filepath, test_folder)
                print(filepath)         


../data/cifar-10-batches-py/data_batch_1
../data/cifar-10-batches-py/data_batch_2
../data/cifar-10-batches-py/data_batch_5
../data/cifar-10-batches-py/test_batch
../data/cifar-10-batches-py/data_batch_4
../data/cifar-10-batches-py/data_batch_3


In [4]:
# Delete the unnecessary folder

import shutil
import os

try:
    shutil.rmtree(data_folder)
    print(f"Successfully deleted the directory and all its contents: {data_folder}")
except Exception as e:
    print(f"Error: {e}")


file_path = 'data/cifar-10-python.tar.gz'
if os.path.exists(file_path):
    os.remove(file_path)
    print(f"{file_path} deleted successfully.")
else:
    print(f"{file_path} does not exist.")

Successfully deleted the directory and all its contents: ../data/cifar-10-batches-py
../data/cifar-10-python.tar.gz deleted successfully.
