In [2]:
import os

import numpy as np
import torch
import torchvision

In [18]:
# Per-image preprocessing, applied in order:
#   1. Resize to 32x32 pixels.
#   2. ToTensor: PIL image -> float tensor with values in [0, 1].
#   3. Normalize(mean=0.5, std=0.5): (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(32, 32)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Load the MNIST train/test splits and wrap each in a DataLoader.
#
# Configuration (previously inlined / duplicated literals):
# - MNIST_ROOT: dataset directory. It used to be hard-coded to one user's home
#   directory; set the MNIST_DATA_DIR environment variable to override it.
#   The original path is kept as the fallback.
# - DOWNLOAD_MNIST: the old comment promised download=True while the code passed
#   False, so a machine without the files failed. torchvision skips the download
#   when the files already exist under the root, so True is safe either way.
# - BATCH_SIZE: shared by both loaders (was the literal 64 in two places).
# - SEED: seeds the shuffle generator so the training-batch order is reproducible.
MNIST_ROOT = os.environ.get("MNIST_DATA_DIR", "/home/kami/Documents/datasets/")
DOWNLOAD_MNIST = True
BATCH_SIZE = 64
SEED = 42

# `transform` (defined above) is applied to every image of both splits.
train_dataset = torchvision.datasets.MNIST(
    root=MNIST_ROOT,
    train=True,
    download=DOWNLOAD_MNIST,
    transform=transform,
)

# Randomize training order; a dedicated seeded generator makes that order
# reproducible across kernel restarts.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Optional: the held-out test split, with the same preprocessing.
test_dataset = torchvision.datasets.MNIST(
    root=MNIST_ROOT,
    train=False,
    download=DOWNLOAD_MNIST,
    transform=transform,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # evaluation order does not need randomizing
)

# Sanity-check one training batch: tensor shapes and the value range produced
# by `transform`. Pulling a single batch with next(iter(...)) replaces the
# for-loop-with-break idiom.
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Image tensor min: {images.min()}, max: {images.max()}")

# Former "Should be ..." comments, now enforced so a regression fails loudly.
# (batch, channels, height, width) -- 32x32 comes from the Resize step.
assert images.shape == (train_loader.batch_size, 1, 32, 32), images.shape
# One label per image.
assert labels.shape == (images.shape[0],), labels.shape
# Normalize(0.5, 0.5) maps ToTensor's [0, 1] onto [-1, 1].
assert images.min() >= -1.0 and images.max() <= 1.0, (images.min(), images.max())




Batch shape: torch.Size([64, 1, 32, 32])
Labels shape: torch.Size([64])
Image tensor min: -1.0, max: 1.0
