In [1]:
import torch

In [2]:
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize with MNIST mean and std
])

In [4]:
train_dataset = torchvision.datasets.MNIST(
    root='./data',  # Where to store the dataset
    train=True,     # This is training data
    download=True,  # Download if not present
    transform=transform  # Apply transformations
)

In [5]:
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Create data loaders for batch processing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [6]:
# Function to display images
def show_images(images, labels):
    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    axes = axes.flatten()

    for i in range(10):
        axes[i].imshow(images[i].reshape(28, 28), cmap='gray')
        axes[i].set_title(f"Label: {labels[i]}")
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()

# Get some random training images
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Show images
#show_images(images[:10], labels[:10])

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [8]:
import app.mnist_cnn

In [9]:
model = app.mnist_cnn.CNN()

In [10]:
import app.trainer as trainer

In [11]:
optimizer = trainer.get_optimizer("sgd", model, 0.001)
loss_fn = trainer.get_loss_function("cross_entropy")

In [12]:
losses = []

for i in range(20):
    loss = trainer.train_one_epoch(model, train_loader, optimizer, loss_fn)
    if i % 2 == 0:
        print(loss)

1.9138111780955593
0.40234535055628207
0.29020818115583363
0.23905755185893476
0.2021788326161566
0.17338140687740433
0.15051468322351416
0.13234924087956199
0.11852188756677515
0.1071349586080164
