In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

In [None]:
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_EPOCHS = 100
SEED = 42  # fixed seed so weight init and DataLoader shuffling are reproducible

# torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# MNIST training split: images are converted to tensors, then shifted by
# mean 0.5 and scaled by std 0.5 per channel.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
mnist_train = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=mnist_transform,
)
# Shuffled mini-batches of BATCH_SIZE samples for the training loop.
train_data = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
class Flatten(nn.Module):
    """Flatten an (N, C, H, W) tensor to (N, H*W*C) in channels-last order.

    The permute to (N, H, W, C) is kept so the following Linear layer's weight
    layout matches checkpoints produced by earlier versions of this model.
    """

    def forward(self, x):
        return torch.flatten(x.permute(0, 2, 3, 1), 1)


class Model(nn.Module):
    """Stack of unpadded conv -> BatchNorm2d -> ReLU blocks followed by a linear classifier.

    Each block uses no padding, so it shrinks H and W by ``kernel_size - 1``.
    After the blocks, the feature map is flattened (channels-last). It then
    passes through a bias-free Linear layer to ``num_classes``, BatchNorm1d,
    and log_softmax.

    All arguments default to the original architecture: 1-channel 28x28
    input, ten 3x3 blocks widening 32 -> 176 channels, and 10 classes.
    ``state_dict`` keys are identical to that architecture's.

    Args:
        in_channels: Channels of the input image.
        num_classes: Number of output classes.
        image_size: Height/width of the (square) input image.
        channels: Output channels of each conv block, in order.
        kernel_size: Square kernel size used by every conv block.

    Raises:
        ValueError: If the conv stack would reduce the spatial size to <= 0.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=28,
                 channels=(32, 48, 64, 80, 96, 112, 128, 144, 160, 176),
                 kernel_size=3):
        super().__init__()
        blocks = []
        prev_channels = in_channels
        spatial = image_size
        for out_channels in channels:
            blocks.append(self._block(prev_channels, out_channels, kernel_size))
            prev_channels = out_channels
            spatial -= kernel_size - 1  # unpadded conv trims kernel_size-1 pixels
        if spatial <= 0:
            raise ValueError(
                f"image_size={image_size} is too small for {len(channels)} "
                f"unpadded {kernel_size}x{kernel_size} conv blocks"
            )
        # Flattened feature size (11264 = 8 * 8 * 176 for the defaults).
        flat_features = spatial * spatial * prev_channels
        self.seq = nn.Sequential(
            *blocks,
            Flatten(),
            nn.Linear(flat_features, num_classes, bias=False),
            nn.BatchNorm1d(num_classes),
        )

    def _block(self, input_dim, output_dim, kernel_size):
        # Conv bias is redundant because BatchNorm2d immediately re-centres the output.
        return nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size, bias=False),
            nn.BatchNorm2d(output_dim),
            nn.ReLU()
        )

    def forward(self, x):
        """Return per-class log-probabilities, shape (N, num_classes)."""
        x = self.seq(x)
        return F.log_softmax(x, dim=1)

In [None]:
# Build the network on the selected device and an Adam optimiser over its weights.
model = Model().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Global batch counter used as the x-axis for logged scalars; TensorBoard
# event files go under ./logs.
step = 0
writer = SummaryWriter(log_dir='logs')

In [None]:
def train(epoch=None):
    """Run one pass over ``train_data``, updating ``model`` in place.

    Each batch's NLL loss is logged to TensorBoard under ``"loss"`` at the
    global ``step`` counter. That counter is incremented once per batch and
    persists across calls.

    Args:
        epoch: Optional epoch index, so callers may write ``train(epoch)``.
            When given, the epoch's mean batch loss is also logged under
            ``"epoch_loss"`` at that index.

    Returns:
        The mean batch loss over the pass as a float, or ``float("nan")`` if
        ``train_data`` yielded no batches.
    """
    global step
    # model.eval() may have been called (e.g. before saving). Without this,
    # BatchNorm would use, and not update, its running statistics during training.
    model.train()
    total_loss = 0.0
    num_batches = 0
    for value, label in train_data:
        value = value.to(device)
        label = label.to(device)
        output = model(value)
        loss = F.nll_loss(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Convert to a Python float so logging/accumulation holds no graph or device tensor.
        loss_value = loss.item()
        writer.add_scalar("loss", loss_value, step)
        total_loss += loss_value
        num_batches += 1
        step += 1
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    if epoch is not None and num_batches:
        writer.add_scalar("epoch_loss", mean_loss, epoch)
    return mean_loss

In [None]:
# Launch the TensorBoard UI inline, reading the 'logs' directory that the
# SummaryWriter('logs') in this notebook writes its event files to.
%tensorboard --logdir logs

In [None]:
# Train for NUM_EPOCHS full passes over the training set. train() takes no
# required arguments; per-batch losses are logged through `writer`.
for _ in range(NUM_EPOCHS):
    train()
# Push buffered events to disk so TensorBoard shows the final losses.
writer.flush()

In [None]:
# Switch to inference mode (BatchNorm uses running statistics) so the in-memory
# model is ready for evaluation. The saved state_dict is unaffected by the mode.
model.eval()
torch.save(model.state_dict(), 'Pytorch.pth')
# Training is finished: flush and release the TensorBoard event file.
writer.close()