In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

In [None]:
# Hyperparameters a reader is most likely to tune.
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_EPOCHS = 30

# Seed PyTorch's random number generators so weight initialisation, DataLoader
# shuffling and dropout start from the same state on every fresh kernel run.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Shuffled mini-batches of the MNIST training split; the dataset is downloaded
# into the 'data' directory when it is not already there.
# Each image is converted to a tensor and then normalised with mean 0.1307 and
# std 0.3081 (presumably the MNIST training-set statistics — TODO confirm).
train_data = DataLoader(
    dataset=datasets.MNIST(
        root='data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
class Model(nn.Module):
    """Convolutional classifier mapping 1x28x28 images to 10 log-probabilities.

    The network is a single ``nn.Sequential`` stored in ``self.cnn``: two
    convolutional stages (two convolutions followed by 2x2 max-pooling each),
    a flatten step, and a fully connected head of sizes 256 -> 128 -> 84 -> 10.
    Layers that are immediately followed by batch normalisation are built with
    ``bias=False``, since the normalisation layer supplies its own shift.
    The final ``LogSoftmax`` means the output is intended for ``F.nll_loss``.
    """

    def __init__(self):
        """Build the layer stack; the per-line comments give the output shape."""
        super().__init__()
        # Input: 1x28x28
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),  # 32x28x28
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2, bias=False),  # 32x28x28
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x14x14
            nn.Conv2d(32, 64, kernel_size=3, stride=1),  # 64x12x12
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, bias=False),  # 64x10x10
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 64x5x5
            # Built-in flatten: keeps the batch dimension and merges the rest.
            # It has no parameters, so state_dict keys are unaffected.
            nn.Flatten(),
            nn.Linear(64 * 5 * 5, 256, bias=False),  # 256
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128, bias=False),  # 128
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 84, bias=False),  # 84
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(84, 10),  # 10
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return per-class log-probabilities.

        Args:
            x: Batch of images shaped ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax scores.
        """
        return self.cnn(x)

class Flatten(nn.Module):
    """Module that turns each sample of a batch into a single flat vector."""

    def forward(self, input):
        """Return ``input`` viewed as a 2-D tensor.

        The first dimension is kept as-is and every remaining dimension is
        merged into the second one.
        """
        leading = input.shape[0]
        flat = input.view(leading, -1)
        return flat

In [None]:
# Build the network and move its parameters to the selected device.
model = Model()
model = model.to(device)
# Adam optimiser over every parameter of the model.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# TensorBoard event files are written under logs/mnist.
writer = SummaryWriter(log_dir='logs/mnist')
# Running count of optimisation steps, shared across epochs.
step = 0

In [None]:
def train(epoch=None):
    """Run one pass over ``train_data``, updating ``model`` after every batch.

    For each mini-batch the inputs and labels are moved to ``device``, the
    negative log-likelihood loss is computed, one optimiser step is taken, and
    the loss is logged to TensorBoard under the tag ``"loss"`` against the
    global ``step`` counter, which is then incremented.

    Args:
        epoch: Optional index of the current epoch. It is accepted so the
            function can be called as ``train(epoch)``; the update itself
            does not use it.
    """
    global step
    # Guarantee dropout and batch-norm are in training mode even if
    # ``model.eval()`` was called earlier in the session.
    model.train()
    for value, label in train_data:
        value = value.to(device)
        label = label.to(device)
        output = model(value)
        loss = F.nll_loss(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log a plain Python float rather than the loss tensor.
        writer.add_scalar("loss", loss.item(), step)
        step += 1

In [None]:
%tensorboard --logdir logs

In [None]:
# Train for NUM_EPOCHS full passes over the training set.
# train() reads its data, model and optimiser from the notebook's globals,
# so it is called without arguments.
for epoch in range(NUM_EPOCHS):
    train()
# Push any buffered events to disk so the TensorBoard view is complete.
writer.flush()

In [None]:
# Switch the network to evaluation mode, then write only its state_dict
# (not the whole module object) to 'Pytorch.pth'.
model.eval()
torch.save(obj=model.state_dict(), f='Pytorch.pth')