In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

def load_images_and_preprocess(directory):
    """
    Load every ``.png`` image in ``directory`` and return model-ready tensors.

    Each file is opened, converted to single-channel grayscale, resized to
    28x28, converted to a tensor and normalized from [0, 1] to [-1, 1].
    The label is the integer that precedes the first ``-`` in the file name
    (e.g. ``'0-3-1.png'`` -> ``0``).

    Files are processed in sorted file-name order so the returned tensors are
    reproducible from run to run (``os.listdir`` alone gives no ordering
    guarantee).

    Args:
        directory (str): Path to the directory containing image files.

    Returns:
        images (torch.Tensor): Tensor of image data (num_images, 1, 28, 28).
        labels (torch.Tensor): ``torch.long`` tensor of labels corresponding
            to the images, shape (num_images,).
        If the directory holds no ``.png`` files, both tensors are empty
        (shapes (0, 1, 28, 28) and (0,)).

    Raises:
        ValueError: If the part of a ``.png`` file name before the first
            ``-`` is not an integer.
        OSError: If the directory cannot be listed.
    """
    # Sort so that sample order does not depend on the filesystem.
    png_files = sorted(
        name for name in os.listdir(directory) if name.endswith(".png")
    )

    if not png_files:
        # torch.stack() raises on an empty list; return correctly shaped
        # empty tensors instead so callers can handle "no data" uniformly.
        return (
            torch.empty((0, 1, 28, 28)),
            torch.empty((0,), dtype=torch.long),
        )

    # Define a PyTorch transformation to normalize to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),  # Converts to tensor and scales [0, 255] to [0, 1]
        transforms.Normalize(mean=0.5, std=0.5)  # Normalizes [0, 1] to [-1, 1]
    ])

    images = []
    labels = []

    for filename in png_files:
        # Parse the label from the filename (e.g., '0-3-1.png' -> label = 0)
        label = int(filename.split('-')[0])

        # Load as grayscale; the context manager closes the underlying file
        # as soon as the converted copy has been produced.
        image_path = os.path.join(directory, filename)
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('L')

        # NOTE(review): the transform above also resizes to 28x28, so this
        # step and transforms.Resize overlap. It is kept so that pixel values
        # stay exactly as before -- confirm which resize is intended and
        # drop the other.
        image = image.resize((28, 28))

        images.append(transform(image))
        labels.append(label)

    # Stack per-image tensors into one batch tensor: (num_images, 1, 28, 28)
    images = torch.stack(images)
    labels = torch.tensor(labels, dtype=torch.long)  # Shape: (num_images,)

    return images, labels

if __name__ == "__main__":
    # `plt` is used below but was not imported anywhere in this cell; without
    # this import a fresh kernel raises NameError, which the broad `except`
    # underneath would turn into a one-line "Error: ..." and no figures.
    import matplotlib.pyplot as plt

    # Directory containing handwritten images
    directory_path = "./digits"

    try:
        # Load and preprocess images and labels
        images, labels = load_images_and_preprocess(directory_path)

        # Labels that have already been shown once.
        displayed_digits = set()

        # Sanity check: display the first image encountered for each label.
        for image, label in zip(images, labels.tolist()):
            if label in displayed_digits:
                continue

            # squeeze() drops the channel axis: (1, 28, 28) -> (28, 28)
            plt.imshow(image.squeeze(), cmap='gray')
            plt.title(f"Label: {label}")
            plt.axis('off')
            plt.show()

            # keep track of displayed digits
            displayed_digits.add(label)

            # Stop once ten distinct labels (expected: digits 0 to 9) are shown.
            if len(displayed_digits) == 10:
                break

    except Exception as e:
        # Best-effort demo cell: report the problem instead of a traceback.
        print(f"Error: {e}")



Loaded 39 images.
Images shape: torch.Size([39, 1, 28, 28])
Labels shape: torch.Size([39])
First image tensor:
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         

In [8]:
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class HandwrittenDigitsDataset(Dataset):
    """
    Dataset of handwritten digit ``.png`` images stored in one directory.

    Each sample is ``(image, label)`` where ``label`` is the integer that
    precedes the first ``-`` in the file name (e.g. ``'0-3-1.png'`` -> ``0``)
    and ``image`` is the result of applying ``transform`` to the grayscale
    PIL image.
    """

    def __init__(self, directory, transform=None):
        """
        Index the ``.png`` files in ``directory``.

        Args:
            directory (str): Path to the directory containing image files.
            transform (callable, optional): Applied to each grayscale PIL
                image in ``__getitem__``. When omitted, images are resized to
                28x28, converted to a tensor and normalized to [-1, 1].
        """
        self.directory = directory
        # Sorted so that a given index always maps to the same file;
        # os.listdir() alone gives no ordering guarantee.
        self.image_files = sorted(
            f for f in os.listdir(directory) if f.endswith(".png")
        )
        self.transform = (
            transform if transform is not None else self._default_transform()
        )

    @staticmethod
    def _default_transform():
        """Build the default resize -> tensor -> normalize pipeline."""
        return transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),  # Convert to tensor and scale [0, 255] to [0, 1]
            transforms.Normalize(mean=0.5, std=0.5)  # Normalize to [-1, 1]
        ])

    def __len__(self):
        """Return the number of ``.png`` files found in the directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load an image and its label.

        Args:
            idx (int): Index of the image file.

        Returns:
            image (torch.Tensor): Preprocessed image tensor (with the
                default transform; otherwise whatever ``transform`` returns).
            label (int): Label extracted from the file name.

        Raises:
            ValueError: If the part of the file name before the first ``-``
                is not an integer.
        """
        file_name = self.image_files[idx]
        label = int(file_name.split('-')[0])  # Extract label from the filename
        image_path = os.path.join(self.directory, file_name)

        # Load as grayscale; the context manager closes the underlying file
        # once the converted copy exists.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('L')
        image = self.transform(image)  # Apply transformations

        return image, label


if __name__ == "__main__":
    # Folder that holds the handwritten digit PNG files.
    directory_path = "./digits"

    try:
        # Wrap the folder in the custom Dataset, then batch it with shuffling.
        dataset = HandwrittenDigitsDataset(directory_path)
        dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

        # Smoke test: describe the first batch only, then stop iterating.
        for batch_idx, (images, labels) in enumerate(dataloader):
            print(
                f"Batch {batch_idx + 1}:",
                f" - Images shape: {images.shape}",  # (batch_size, 1, 28, 28)
                f" - Labels shape: {labels.shape}",  # (batch_size,)
                f" - Labels: {labels.tolist()}",
                sep="\n",
            )
            break
    except Exception as e:
        # Demo cell: surface the failure as a single line rather than a traceback.
        print(f"Error: {e}")


Batch 1:
 - Images shape: torch.Size([16, 1, 28, 28])
 - Labels shape: torch.Size([16])
 - Labels: [9, 4, 4, 3, 0, 3, 8, 2, 8, 9, 7, 9, 8, 1, 2, 0]
