In [2]:
import torch
from torch import nn
import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

class Reshape(nn.Module):
	def forward(self, x):
		return x.view(-1, 1, 28, 28)


class LeNet5(nn.Module):
	def __init__(self):
		super(LeNet5, self).__init__()
		self.net = nn.Sequential(
			Reshape(),

			# CONV1, ReLU1, POOL1
			nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
			# nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
			nn.ReLU(),
			nn.MaxPool2d(kernel_size=2, stride=2),

			# CONV2, ReLU2, POOL2
			nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
			nn.ReLU(),
			nn.MaxPool2d(kernel_size=2, stride=2),
			nn.Flatten(),

			# FC1
			nn.Linear(in_features=16 * 5 * 5, out_features=120),
			nn.ReLU(),

			# FC2
			nn.Linear(in_features=120, out_features=84),
			nn.ReLU(),

			# FC3
			nn.Linear(in_features=84, out_features=10)
		)


	def forward(self, x):
		logits = self.net(x)
		return logits






In [3]:
# DATASET
train_data = datasets.MNIST(
	root='./data',
	train=False,
	download=True,
	transform=ToTensor()
)

test_data = datasets.MNIST(
	root='./data',
	train=False,
	download=True,
	transform=ToTensor()
)


# PREPROCESS
batch_size = 256
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)
for X, y in train_dataloader:
	print(X.shape)		# torch.Size([256, 1, 28, 28])
	print(y.shape)		# torch.Size([256])
	break


# MODEL
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = LeNet5().to(device)


# TRAIN MODEL
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

def train(dataloader, model, loss_func, optimizer, epoch):
	model.train()
	data_size = len(dataloader.dataset)
	for batch, (X, y) in enumerate(dataloader):
		X, y = X.to(device), y.to(device)

		y_hat = model(X)
		loss = loss_func(y_hat, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

	loss, current = loss.item(), batch * len(X)
	print(f'EPOCH{epoch+1}\tloss: {loss:>7f}', end='\t')


# Test model
def test(dataloader, model, loss_fn):
	size = len(dataloader.dataset)
	num_batches = len(dataloader)
	model.eval()
	test_loss, correct = 0, 0
	with torch.no_grad():
		for X, y in dataloader:
			X, y = X.to(device), y.to(device)
			pred = model(X)
			test_loss += loss_fn(pred, y).item()
			correct += (pred.argmax(1) == y).type(torch.float).sum().item()
	test_loss /= num_batches
	correct /= size
	print(f'Test Error: Accuracy: {(100 * correct):>0.1f}%, Average loss: {test_loss:>8f}\n')


if __name__ == '__main__':
	epoches = 20
	for epoch in range(epoches):
		train(train_dataloader, model, loss_func, optimizer, epoch)
		test(test_dataloader, model, loss_func)

	# Save models
	torch.save(model.state_dict(), 'model.pth')
	print('Saved PyTorch LeNet5 State to model.pth')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


102.8%


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


112.7%


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

torch.Size([256, 1, 28, 28])
torch.Size([256])
EPOCH1	loss: 0.660427	Test Error: Accuracy: 69.6%, Average loss: 0.894932

EPOCH2	loss: 0.277688	Test Error: Accuracy: 86.5%, Average loss: 0.422568

EPOCH3	loss: 0.092691	Test Error: Accuracy: 91.1%, Average loss: 0.287984

EPOCH4	loss: 0.053176	Test Error: Accuracy: 92.8%, Average loss: 0.232429

EPOCH5	loss: 0.030145	Test Error: Accuracy: 93.6%, Average loss: 0.203683

EPOCH6	loss: 0.029470	Test Error: Accuracy: 95.1%, Average loss: 0.155091

EPOCH7	loss: 0.019340	Test Error: Accuracy: 96.0%, Average loss: 0.127239

EPOCH8	loss: 0.014191	Test Error: Accuracy: 96.7%, Average loss: 0.108613

EPOCH9	loss: 0.010716	Test Error: Accuracy: 97.1%, Average loss: 0.095264

EPOCH10	loss: 0.007580	Test Error: Accuracy: 97.4%, Average loss: 0.084857

EPOCH11	loss: 0.005558	Test Error: Accuracy: 97.6%, Average loss: 0.076358

EPOCH12	loss: 0.004220	Test Error: Accuracy: 97.8%,