In [2]:
import copy
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from torchsummary import summary

# Directory where torchvision stores the downloaded MNIST archives.
root = "./data"

# Load both MNIST splits (train first, then test), downloading on first use.
# ToTensor() is only a placeholder: the data-preparation cell later replaces
# these transforms with normalized pipelines.
train_data, test_data = (
    datasets.MNIST(root=root, train=is_train_split, download=True, transform=transforms.ToTensor())
    for is_train_split in (True, False)
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 59.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.73MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.7MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.68MB/s]


In [9]:
# Reproducibility: seed every RNG the notebook imports. random_split also gets
# its own seeded generator, so the train/validation partition is identical on
# every run instead of changing each time this cell executes.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fraction of the MNIST training images kept for *training*; the remainder
# becomes the validation set. (Despite its name, VALID_RATIO is the training
# share -- the name is kept so later cells that reference it keep working.)
VALID_RATIO = 0.9


def get_base_dataset(dataset):
    """Return the innermost dataset wrapped by (possibly nested) ``Subset`` objects.

    Args:
        dataset: a ``torch.utils.data.Dataset``, optionally wrapped one or more
            times in ``torch.utils.data.Subset``.

    Returns:
        The first object in the ``.dataset`` chain that is not a ``Subset``
        (``dataset`` itself when it is not a ``Subset``).
    """
    while isinstance(dataset, data.Subset):
        dataset = dataset.dataset
    return dataset


# Always split the *full* MNIST training set. Unwrapping first makes the cell
# idempotent: re-running it no longer re-splits the previous 90% subset.
base_dataset = get_base_dataset(train_data)
n_train_examples = int(len(base_dataset) * VALID_RATIO)
n_valid_examples = len(base_dataset) - n_train_examples

train_split, valid_split = data.random_split(
    base_dataset,
    [n_train_examples, n_valid_examples],
    generator=torch.Generator().manual_seed(SEED),
)

# Normalization statistics come from the training split only, so nothing
# about the validation images leaks into preprocessing.
# NOTE(review): relies on `.data` being the raw uint8 image tensor exposed by
# torchvision's MNIST (true for datasets.MNIST).
train_pixels = base_dataset.data[train_split.indices].float() / 255
mean = train_pixels.mean().item()
std = train_pixels.std().item()
del train_pixels  # large float copy of the images; no longer needed

# The train and eval pipelines are identical today. They are kept separate so
# augmentation can be added to training without touching evaluation.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,)),
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,)),
])

# Subsets from random_split share ONE underlying dataset object, so assigning
# `valid_data.dataset.transform` would also overwrite the training transform.
# Give each split its own shallow copy: image tensors are shared (no extra
# memory), but each copy has an independent `transform` attribute.
train_base = copy.copy(base_dataset)
train_base.transform = train_transforms
valid_base = copy.copy(base_dataset)
valid_base.transform = test_transforms

train_data = data.Subset(train_base, train_split.indices)
valid_data = data.Subset(valid_base, valid_split.indices)
test_data.transform = test_transforms

assert len(train_data) + len(valid_data) == len(base_dataset)
assert train_data.dataset.transform is not valid_data.dataset.transform

BATCH_SIZE = 256
train_dataloader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = data.DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
class LeNet(nn.Module):
    """LeNet-5-style CNN: two conv + average-pool stages, then three linear layers.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input images (1 for MNIST).
        img_size: height/width of the square input images; used to size fc1.

    Expects input of shape ``(batch, in_channels, img_size, img_size)`` and
    returns logits of shape ``(batch, num_classes)``.

    Raises:
        ValueError: if ``img_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, num_classes, in_channels=1, img_size=28):
        super().__init__()

        # conv1 preserves the spatial size ("same" padding).
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5, padding="same")
        self.pool1 = nn.AvgPool2d(kernel_size=2)

        # conv2 is unpadded ("valid") and shrinks each side by 4, as in LeNet-5.
        # It previously used "same" padding, which produced 16*7*7 features for
        # 28x28 inputs while fc1 expected 16*5*5, so forward() crashed.
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.AvgPool2d(kernel_size=2)

        self.flatten = nn.Flatten()

        # Derived from img_size: 28 -> 5 and 150 -> 35, matching the previously
        # hard-coded sizes, and correct for any other square input size too.
        fc_input_size = 16 * self._feature_side(img_size) ** 2

        self.fc1 = nn.Linear(fc_input_size, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    @staticmethod
    def _feature_side(img_size):
        """Return the side length of the pool2 output for a square ``img_size`` input."""
        side = img_size // 2        # conv1 ("same") keeps size; pool1 halves (floor)
        side = (side - 4) // 2      # conv2 (5x5, no padding) loses 4; pool2 halves
        if side < 1:
            raise ValueError(f"img_size={img_size} is too small for LeNet (need >= 12)")
        return side

    def forward(self, x):
        """Map a batch of images to class logits (no softmax applied)."""
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.fc1(self.flatten(x))
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x

<torch.utils.data.dataset.Subset at 0x7b5e263132c0>