In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import torchvision
import torchvision.transforms as transforms



In [2]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

import pdb

torch.set_printoptions(linewidth = 120)

In [3]:
# Fashion-MNIST training split, stored under ./data/FashionMNIST
# (download=True is passed so the files can be fetched into that folder).
# Every sample is passed through ToTensor() before it is returned.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [4]:
class Network(nn.Module):
    def __init__(self):
        super(Network,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        
        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
        
    def forward(self,t):
        #(1) input layer
        t = t
        
        #(2) hidden conv layer
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)
        
        #(3) hidden conv layer
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)
        
        #(4) hidden linear layer
        t = t.reshape(-1,12 * 4 * 4)
        t = self.fc1(t)
        t = F.relu(t)
        
        #(5) hidden linear layer
        t = self.fc2(t)
        t = F.relu(t)
        
        #(6) output layer
        t = self.out(t)
        #t = F.softmax(t, dim=1)
        
        return t
    

In [5]:
# Seed PyTorch's global RNG before the first stochastic step (the random
# weight initialisation done by the layers inside Network) so that
# Restart Kernel -> Run All starts from the same model every time.
torch.manual_seed(0)
network = Network()

In [6]:
# Mini-batches of 100 samples. shuffle=True makes the loader draw a new
# random ordering each time it is iterated, so successive epochs do not
# replay exactly the same batches in exactly the same order.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=100,
    shuffle=True,
)

# Adam over all of the network's parameters.
optimizer = optim.Adam(network.parameters(), lr=0.01)

In [8]:
def get_num_correct(preds, labels):
    """Count how many rows of ``preds`` have their largest score at the label index.

    ``preds`` is indexed as (sample, class) and ``labels`` holds one class index
    per sample. The count is returned as a plain Python number.
    """
    predicted_classes = torch.argmax(preds, dim=1)
    hits = torch.eq(predicted_classes, labels)
    return hits.sum().item()

In [9]:
# Number of full passes over the training data.
NUM_EPOCHS = 50

# Pull one batch up front: it is logged as an image grid and is also the
# example input used to record the model graph.
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)

# Using the writer as a context manager guarantees it is closed (and its
# pending events written out) even when training raises or is interrupted
# part-way through, which a trailing tb.close() would not.
with SummaryWriter() as tb:
    tb.add_image('images', grid)
    tb.add_graph(network, images)

    for epoch in range(NUM_EPOCHS):

        total_loss = 0
        total_correct = 0

        for batch in train_loader:
            images, labels = batch

            preds = network(images)
            loss = F.cross_entropy(preds, labels)

            # Clear the previous step's gradients, backpropagate, then update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # total_loss accumulates one loss value per batch, so its
            # magnitude depends on how many batches make up an epoch.
            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

        # Per-epoch scalars.
        tb.add_scalar('Loss', total_loss, epoch)
        tb.add_scalar('Number Correct', total_correct, epoch)
        tb.add_scalar('Accuracy', total_correct / len(train_set), epoch)

        # Parameter distributions of the first conv layer. The .grad logged
        # here is whatever the last batch of the epoch left behind, because
        # gradients are zeroed before every backward pass.
        tb.add_histogram('conv1.bias', network.conv1.bias, epoch)
        tb.add_histogram('conv1.weight', network.conv1.weight, epoch)
        tb.add_histogram(
            'conv1.weight.grad',
            network.conv1.weight.grad,
            epoch,
        )

        print('epoch:', epoch,
              'total_correct:', total_correct,
              'loss:', total_loss)

epoch: 0 total_correct: 47988 loss: 319.5780900269747
epoch: 1 total_correct: 51861 loss: 219.90778167545795
epoch: 2 total_correct: 52371 loss: 204.616062566638
epoch: 3 total_correct: 52649 loss: 197.41518412530422
epoch: 4 total_correct: 53012 loss: 188.01862213015556
epoch: 5 total_correct: 53177 loss: 183.28284583985806
epoch: 6 total_correct: 53258 loss: 181.4121125638485
epoch: 7 total_correct: 53294 loss: 181.53220336139202
epoch: 8 total_correct: 53400 loss: 178.20476151257753
epoch: 9 total_correct: 53453 loss: 178.9490063637495
epoch: 10 total_correct: 53633 loss: 170.74077335000038
epoch: 11 total_correct: 53715 loss: 170.99706882983446
epoch: 12 total_correct: 53805 loss: 169.94547184556723
epoch: 13 total_correct: 53877 loss: 167.18902680277824
epoch: 14 total_correct: 53882 loss: 167.71564829349518
epoch: 15 total_correct: 53996 loss: 163.1985973417759
epoch: 16 total_correct: 53948 loss: 165.8886448070407
epoch: 17 total_correct: 53988 loss: 165.68597032874823
epoch: 18