# In [1]:  (Jupyter notebook cell prompt)
import numpy as np
import struct
from PIL import Image
import os

def read_idx(filename):
    """Load an IDX-format file (the container used by MNIST) as a NumPy array.

    The IDX header is two zero bytes, one byte naming the element type, one
    byte giving the number of dimensions, and then one big-endian 32-bit
    size per dimension. The remaining bytes are the elements in row-major
    order, stored big-endian.

    Args:
        filename: Path of the IDX file to read.

    Returns:
        A read-only ``numpy.ndarray`` whose shape is taken from the header
        and whose dtype matches the header's type code (``uint8`` for the
        standard MNIST image and label files).

    Raises:
        ValueError: If the file does not start with two zero bytes, the
            type code is unknown, or the payload size does not match the
            declared shape.
        struct.error: If the file is too short to contain the full header.
    """
    # IDX type codes -> NumPy dtypes; multi-byte values are big-endian.
    idx_dtypes = {
        0x08: np.dtype(np.uint8),
        0x09: np.dtype(np.int8),
        0x0B: np.dtype('>i2'),
        0x0C: np.dtype('>i4'),
        0x0D: np.dtype('>f4'),
        0x0E: np.dtype('>f8'),
    }
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0:
            raise ValueError(
                f'{filename!r} is not an IDX file: the first two bytes must be zero'
            )
        if data_type not in idx_dtypes:
            raise ValueError(
                f'{filename!r} has unsupported IDX data type code 0x{data_type:02X}'
            )
        # Read every dimension size in one unpack instead of one read per axis.
        shape = struct.unpack(f'>{dims}I', f.read(4 * dims))
        return np.frombuffer(f.read(), dtype=idx_dtypes[data_type]).reshape(shape)

def save_images(images, labels, path='images', prefix='image'):
    """Write each image as a PNG file inside a per-label subfolder of ``path``.

    Images and labels are paired positionally with ``zip``, so iteration
    stops at the shorter of the two. The image at position ``i`` with label
    ``L`` is saved as ``<path>/<L>/<prefix>_<i>.png``.

    Args:
        images: Iterable of arrays accepted by ``PIL.Image.fromarray``.
        labels: Iterable of labels; ``str(label)`` is used as the folder name.
        path: Root output folder; created if it does not exist.
        prefix: Filename prefix placed before the image index.
    """
    # exist_ok=True replaces the check-then-create sequence, which could fail
    # if the folder appeared between the existence check and the makedirs call.
    os.makedirs(path, exist_ok=True)

    # Remember which class folders were already ensured so the filesystem is
    # touched once per label rather than once per image.
    ensured_dirs = set()
    for i, (image, label) in enumerate(zip(images, labels)):
        class_path = os.path.join(path, str(label))
        if class_path not in ensured_dirs:
            os.makedirs(class_path, exist_ok=True)
            ensured_dirs.add(class_path)
        img = Image.fromarray(image)
        img.save(os.path.join(class_path, f'{prefix}_{i}.png'))

def check_mnist_folder(folder_path):
    """Locate the MNIST training image and label files inside ``folder_path``.

    Returns:
        ``(images_file, labels_file)`` with the joined paths when both
        ``train-images.idx3-ubyte`` and ``train-labels.idx1-ubyte`` exist,
        otherwise ``(None, None)``.
    """
    expected = tuple(
        os.path.join(folder_path, file_name)
        for file_name in ('train-images.idx3-ubyte', 'train-labels.idx1-ubyte')
    )
    # Bail out as soon as one of the two files is missing.
    if not all(os.path.exists(candidate) for candidate in expected):
        return None, None
    return expected

def process_mnist_folder(mnist_folder):
    """Export the MNIST training set found in ``mnist_folder`` as PNG files.

    Looks up the image and label files with ``check_mnist_folder``. If either
    is missing, a message is printed and nothing else happens. Otherwise both
    files are loaded with ``read_idx``, the images are written under
    ``<mnist_folder>/images`` via ``save_images``, and the label array is
    printed.
    """
    images_path, labels_path = check_mnist_folder(mnist_folder)

    # Guard clause: nothing to do unless both files were located.
    if not (images_path and labels_path):
        print("MNIST files not found in the specified folder.")
        return

    images, labels = read_idx(images_path), read_idx(labels_path)

    # Output goes into an 'images' subfolder of the MNIST folder itself.
    save_images(images, labels, path=os.path.join(mnist_folder, 'images'))

    # Show the labels that were just exported.
    print(labels)

# Folder that is searched for the MNIST IDX files
# (train-images.idx3-ubyte and train-labels.idx1-ubyte).
# NOTE: this is a machine-specific absolute path; adjust it before running.
mnist_folder = r'C:\Users\lewka\Downloads\archive'

# Run the export: PNGs are written to an 'images' subfolder of mnist_folder
# and the label array is printed.
process_mnist_folder(mnist_folder)


# [notebook output] KeyboardInterrupt

