In [1]:
%reload_ext tensorboard

import torch.nn as nn
import torch
import numpy as np
import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.utils.tensorboard import SummaryWriter

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
batch_size = 64
learning_rate = 0.01
num_epochs = 8

In [4]:
train_ds = torchvision.datasets.FashionMNIST(
    './data', download=True, train=True, transform=torchvision.transforms.ToTensor())
test_ds = torchvision.datasets.FashionMNIST(
    './data', transform=torchvision.transforms.ToTensor())

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size)

In [5]:
writer = SummaryWriter(log_dir='runs/mnist_logs')


In [6]:

# Visualize Data

train_iter = iter(train_dl)
samples, labels = train_iter.next()
print (samples.shape)

torch.Size([64, 1, 28, 28])


In [7]:
grid = torchvision.utils.make_grid(samples)
print (grid.shape)
writer.add_image('samples', grid, 0)

torch.Size([3, 242, 242])


In [8]:
class FashionClassifier(nn.Module):
    def __init__(self, input_channels, num_classes):
        super(FashionClassifier, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 6, 5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.lin1 = nn.Linear(16*4*4, 150)
        self.lin2 = nn.Linear(150, 84)
        self.lin3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 16*4*4)
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.relu(self.lin3(x))
        return x

In [9]:
model = FashionClassifier(samples.shape[1], 10)
writer.add_graph(model, samples)

criterion = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=learning_rate)


In [10]:
def run_eval(model, test_dl, epoch):

    with torch.no_grad():
        num_correct = 0
        num_samples = 0
        for i, (samples,labels) in enumerate(test_dl):
            images = samples.to(device)
            y = labels.to(device)
            y_pred = model(images)
            _, predicted = torch.max(y_pred, 1)
            num_correct += (predicted == y).sum().item()
            num_samples += labels[0]

        acc = (100 * num_correct) / num_samples
        writer.add_scalar('mnist_val_acc', acc, epoch)

In [11]:
for epoch in range(num_epochs):
    for i,(samples,labels) in enumerate(train_dl):
        images = samples.to(device)
        y = labels.to(device)
        pred = model(images)
        loss = criterion(pred, y)
        optim.zero_grad()
        loss.backward()
        optim.step()

        if (i + 1) % 100 == 0:
            print(f'epoch {epoch} / num_epochs {num_epochs-1}, step {i + 1}, training loss = {loss.item():.4f}')
    
    run_eval(model, test_dl, epoch)

writer.close() 

epoch 0 / num_epochs 7, step 100, training loss = 2.3059
epoch 0 / num_epochs 7, step 200, training loss = 2.3013
epoch 0 / num_epochs 7, step 300, training loss = 2.2997
epoch 0 / num_epochs 7, step 400, training loss = 2.2945
epoch 0 / num_epochs 7, step 500, training loss = 2.2926
epoch 0 / num_epochs 7, step 600, training loss = 2.2915
epoch 0 / num_epochs 7, step 700, training loss = 2.2743
epoch 0 / num_epochs 7, step 800, training loss = 2.2419
epoch 0 / num_epochs 7, step 900, training loss = 2.2129
epoch 1 / num_epochs 7, step 100, training loss = 2.0550
epoch 1 / num_epochs 7, step 200, training loss = 1.9104
epoch 1 / num_epochs 7, step 300, training loss = 1.7568
epoch 1 / num_epochs 7, step 400, training loss = 2.0054
epoch 1 / num_epochs 7, step 500, training loss = 1.6775
epoch 1 / num_epochs 7, step 600, training loss = 1.4911
epoch 1 / num_epochs 7, step 700, training loss = 1.3871
epoch 1 / num_epochs 7, step 800, training loss = 1.2499
epoch 1 / num_epochs 7, step 90

In [12]:
%tensorboard --logdir=runs/mnist_logs

Reusing TensorBoard on port 6006 (pid 3991), started 5:04:57 ago. (Use '!kill 3991' to kill it.)