In [None]:
%pylab inline

import os

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets
from torchvision.transforms import ToTensor

from tqdm import trange

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
def _load_mnist(train, **extra):
	"""Return the MNIST split selected by ``train`` from ./data as tensors.

	Any extra keyword arguments (e.g. ``download=True``) are forwarded to
	``datasets.MNIST`` unchanged.
	"""
	return datasets.MNIST(
		root='data',
		train=train,
		transform=ToTensor(),
		**extra,
	)


# Only the training split requests a download, mirroring the original cell.
train_data = _load_mnist(True, download=True)
test_data = _load_mnist(False)

In [None]:
# Index the dataset once and unpack the (image, label) pair it returns.
sample_image, sample_label = train_data[0]
print(sample_image.shape)
print(sample_label)

In [None]:
# Use the explicit matplotlib API instead of the bare `imshow` that only
# exists through the %pylab star-import.
first_image, first_label = train_data[0]

fig, ax = plt.subplots()
# Drop the leading single-channel dim so imshow receives a 2-D array; a gray
# colormap shows the digit as stored rather than in false colour.
ax.imshow(first_image.squeeze(0), cmap='gray')
ax.set_title(f'train_data[0], label {first_label}')
ax.axis('off')
plt.show()

In [None]:
class ShitNet(nn.Module):
	"""Small CNN for single-channel images with a 10-way linear output.

	Two Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d stages feed a stack of fully
	connected layers. The width of the first fully connected layer is derived
	once, at construction time, by pushing a zero tensor of shape
	``(1, *in_shape)`` through the conv stack. Every layer therefore exists
	before an optimizer is built from ``model.parameters()``.

	Args:
		in_shape: ``(channels, height, width)`` of one input image. Defaults to
			``(1, 28, 28)``. The channel count must be 1 because the first
			conv layer takes a single input channel.
	"""

	def __init__(self, in_shape=(1, 28, 28)):
		super().__init__()

		self.convs = nn.Sequential(
			nn.Conv2d(1, 32, 4),
			nn.BatchNorm2d(32),
			nn.ReLU(inplace=True),
			nn.MaxPool2d(2),

			nn.Conv2d(32, 64, 4),
			nn.BatchNorm2d(64),
			nn.ReLU(inplace=True),
			nn.MaxPool2d(2),
		)

		self.dropout = nn.Dropout(0.15)

		# Probe the conv stack once to learn the flattened feature size.
		# eval() keeps the probe from updating BatchNorm running statistics;
		# no_grad() avoids building an autograd graph for it.
		self.convs.eval()
		with torch.no_grad():
			probe = self.convs(torch.zeros(1, *in_shape))
		self.convs.train()

		self.set_classifier(probe.shape)

	def set_classifier(self, x_shape):
		"""(Re)build the fully connected head for conv output of shape ``x_shape``.

		Args:
			x_shape: 4-D shape ``(batch, channels, height, width)`` of the conv
				stack's output. Only the last three dims are used.

		Note:
			This replaces ``self.lins`` and ``self.classifier`` with freshly
			initialised layers. An optimizer built earlier will not see the new
			parameters, so only call this before constructing the optimizer.
		"""
		lin_size = x_shape[1] * x_shape[2] * x_shape[3]
		self.lins = nn.Sequential(
			nn.Linear(lin_size, 4096),
			nn.ReLU(inplace=True),

			nn.Linear(4096, 2064),
			nn.ReLU(inplace=True),

			nn.Linear(2064, 1032),
			nn.ReLU(inplace=True),
		)

		self.classifier = nn.Linear(1032, 10)

	def forward(self, x):
		"""Return raw (pre-softmax) class scores of shape ``(batch, 10)``."""
		x = self.convs(x)
		x = self.dropout(x)
		# Flatten everything except the batch dim for the linear layers.
		x = x.reshape(x.shape[0], -1)
		x = self.lins(x)
		return self.classifier(x)


In [None]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# repeat identically on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = ShitNet()

# Adam only updates the tensors returned by model.parameters() at this point;
# any layer created after this line would never be trained.
LR = 3e-4
optimizer = optim.Adam(model.parameters(), lr=LR)

loss_fn = nn.CrossEntropyLoss()

EPOCHS = 10
BATCH = 4

In [None]:

# Never ask for more worker processes than the machine has cores; an
# oversubscribed DataLoader can slow down or stall.
trainloader = torch.utils.data.DataLoader(
	train_data,
	batch_size=BATCH,
	num_workers=min(12, os.cpu_count() or 1),
	shuffle=True,
)

In [None]:
# Pull a single (images, labels) batch and show the first label.
sample_images, sample_labels = next(iter(trainloader))
print(sample_labels[0])

In [None]:
accuracies = []  # per-step training accuracy (fraction of the batch correct)
losses = []      # per-step training loss

model.train()

for epoch in trange(EPOCHS):
	running_loss = 0.0
	running_correct = 0
	running_seen = 0
	for images, labels in trainloader:
		optimizer.zero_grad()

		out = model(images)
		loss = loss_fn(out, labels)
		loss.backward()
		optimizer.step()

		# Read the scalar once and reuse it for the epoch sum and the history.
		loss_value = loss.item()
		running_loss += loss_value

		predicted = out.argmax(dim=1)
		batch_correct = (predicted == labels).sum().item()
		batch_size = labels.size(0)
		running_correct += batch_correct
		running_seen += batch_size

		accuracies.append(batch_correct / batch_size)
		losses.append(loss_value)

	print('ep', epoch, 'loss', running_loss / len(trainloader), 'acc', running_correct / running_seen)

In [None]:
fig, ax = plt.subplots()
# With a single y sequence, plot() uses range(len(y)) as x, i.e. the step index.
ax.plot(accuracies, 'g', label='Accuracy')
ax.plot(losses, 'b', label='Loss')
ax.set_title('Accuracy and Loss')
ax.set_xlabel('training step')
# Without legend() the label= arguments above are never rendered.
ax.legend()
plt.show()

In [None]:
# In eval mode (BatchNorm uses running stats, dropout is off) each sample's
# prediction is independent of the batch it is in, so a large batch gives the
# same accuracy as batch_size=1 in far fewer iterations.
TEST_BATCH = 1000

testloader = torch.utils.data.DataLoader(
	test_data,
	batch_size=TEST_BATCH,
	num_workers=min(12, os.cpu_count() or 1),
	shuffle=False,
)

In [None]:
# Switch BatchNorm/Dropout to inference behaviour before scoring.
model.eval()

correct = 0
total = 0

with torch.no_grad():
	for images, labels in testloader:
		predicted = model(images).argmax(dim=1)
		total += labels.size(0)
		correct += (predicted == labels).sum().item()

print('total', total)
print('correct', correct)
print('accuracy', round(correct / total, 3))