# 05. Data Loading

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/pytorch_tutorial/blob/main/05_data_loading/demo.ipynb)

---

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

## Loading MNIST

In [None]:
# Preprocessing pipeline: ToTensor turns each image into a float tensor scaled
# to [0, 1], then Normalize standardizes it with the mean/std values commonly
# used for MNIST (0.1307 / 0.3081, single grayscale channel).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split: fetched into ./data if it is not already there.
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# Test split: no download flag, so it relies on the files fetched above.
# NOTE(review): assumes the download step retrieves both splits — confirm if
# this cell is ever reordered or split.
test_data = datasets.MNIST(root='./data', train=False, transform=transform)

# Report how many samples each split contains.
for split_name, split_ds in (('Training', train_data), ('Test', test_data)):
    print(f'{split_name} samples: {len(split_ds)}')

In [None]:
# Wrap each split in a DataLoader yielding mini-batches of 64 samples.
# Only the training loader shuffles (reshuffled every epoch); the test loader
# keeps a fixed order so evaluation is reproducible.
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

# Draw one mini-batch from the training loader to inspect tensor shapes.
first_batch = next(iter(train_loader))
images, labels = first_batch
for tensor_name, tensor in (('images', images), ('labels', labels)):
    print(f'Batch {tensor_name} shape: {tensor.shape}')

## Visualize a Sample Batch

In [None]:
import matplotlib.pyplot as plt

# Show the first images of the batch fetched above in a 2x5 grid, each titled
# with its ground-truth label.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))

# Hide every panel's frame/ticks up front, so any panel left without an image
# (batch smaller than the grid) renders blank instead of as an empty plot.
for ax in axes.flat:
    ax.axis('off')

# zip stops at the shorter sequence: a batch with fewer than 10 images no
# longer raises IndexError, unlike indexing images[i] for every axis.
for ax, image, label in zip(axes.flat, images, labels):
    # squeeze() drops size-1 dims (e.g. the single channel) so imshow gets a
    # 2-D array; imshow auto-scales the normalized values for display.
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(f'Label: {label.item()}')

plt.tight_layout()
plt.show()