In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import torchvision
import torchvision.transforms as transforms



In [2]:
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

import pdb

# Let printed tensors use up to 120 characters per line before wrapping.
torch.set_printoptions(linewidth=120)

In [3]:
# FashionMNIST training split. download=True lets torchvision fetch the files
# into `root`; each PIL image is converted to a tensor via ToTensor().
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [4]:
class Network(nn.Module):
    """Small CNN mapping single-channel images to 10 class logits.

    Layout: two [conv 5x5 -> ReLU -> max-pool 2x2] stages (1 -> 6 -> 12
    channels) followed by three fully connected layers (192 -> 120 -> 60 -> 10).

    ``fc1``'s input width of 12*4*4 assumes 28x28 inputs (e.g. FashionMNIST):
    28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Compute raw class scores for a batch of images.

        Args:
            t: tensor of shape (batch, 1, H, W); ``fc1`` is sized for H = W = 28.

        Returns:
            Tensor of shape (batch, 10) holding unnormalized logits. No softmax
            is applied here because ``F.cross_entropy`` expects logits.

        Raises:
            RuntimeError: if the spatial size does not produce 12*4*4 features
                per sample (the flattened width will not match ``fc1``).
        """
        # (1) hidden conv layer
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (2) hidden conv layer
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (3) hidden linear layer
        # Flatten everything except the batch axis. Unlike reshape(-1, 192),
        # this can never change the number of rows, so a wrong input size
        # raises at fc1 instead of silently mixing samples together.
        t = t.flatten(start_dim=1)
        t = self.fc1(t)
        t = F.relu(t)

        # (4) hidden linear layer
        t = self.fc2(t)
        t = F.relu(t)

        # (5) output layer: raw logits
        return self.out(t)
    

In [5]:
# network = Network()

In [6]:
# train_loader = torch.utils.data.DataLoader(
#     train_set,
#     batch_size=100
# )
# optimizer = optim.Adam(network.parameters(), lr = 0.01)

In [7]:
def get_num_correct(preds, labels):
    """Count predictions whose highest-scoring class equals the label.

    Args:
        preds: tensor of per-class scores, classes along dim 1.
        labels: tensor of class indices, one per row of ``preds``.

    Returns:
        int: how many rows satisfy ``argmax(preds[i]) == labels[i]``.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes == labels
    return int(matches.sum())

In [8]:
# Hyperparameter values to sweep: DataLoader batch sizes and Adam learning rates.
batch_size_list = [10**2, 10**3, 10**4]
lr_list = [1e-2, 1e-3, 1e-4, 1e-5]

In [11]:
# Hyperparameter sweep: for every (batch_size, lr) pair, train a freshly
# initialised Network on train_set and log the run to its own TensorBoard
# run (distinguished by the `comment` suffix).
NUM_EPOCHS = 5

for batch_size in batch_size_list:
    for lr in lr_list:
        network = Network()

        # shuffle=True: present the training samples in a new order every epoch.
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True
        )
        optimizer = optim.Adam(network.parameters(), lr=lr)

        # One batch is used only for the image grid and the model graph.
        images, labels = next(iter(train_loader))
        grid = torchvision.utils.make_grid(images)

        comment = f' batch_size={batch_size} lr={lr}'
        # The `with` block closes (and flushes) the writer even if training raises.
        with SummaryWriter(comment=comment) as tb:
            tb.add_image('images', grid)
            tb.add_graph(network, images)

            for epoch in range(NUM_EPOCHS):
                total_loss = 0
                total_correct = 0

                for batch in train_loader:
                    images, labels = batch

                    preds = network(images)                # forward pass -> logits
                    loss = F.cross_entropy(preds, labels)  # mean loss over the batch

                    optimizer.zero_grad()  # clear gradients from the previous step
                    loss.backward()        # compute gradients
                    optimizer.step()       # update weights

                    # `loss` is a per-sample mean; weight it by this batch's real
                    # length, since a final batch can be smaller than batch_size.
                    total_loss += loss.item() * images.shape[0]
                    total_correct += get_num_correct(preds, labels)

                tb.add_scalar('Loss', total_loss, epoch)
                tb.add_scalar('Number Correct', total_correct, epoch)
                tb.add_scalar('Accuracy', total_correct / len(train_set), epoch)

                # Log each parameter and its gradient exactly once per epoch.
                for name, param in network.named_parameters():
                    tb.add_histogram(name, param, epoch)
                    tb.add_histogram(f'{name}.grad', param.grad, epoch)

                print('epoch:', epoch,
                      'total_correct:', total_correct,
                      'loss:', total_loss)

epoch: 0 total_correct: 47585 loss: 33033.57497751713
epoch: 1 total_correct: 51485 loss: 23053.107208013535
epoch: 2 total_correct: 52071 loss: 21369.594030082226
epoch: 3 total_correct: 52386 loss: 20374.2847725749
epoch: 4 total_correct: 52562 loss: 19908.495746552944
epoch: 0 total_correct: 42015 loss: 46570.40399312973
epoch: 1 total_correct: 48276 loss: 30957.88463652134
epoch: 2 total_correct: 50276 loss: 26423.177137970924
epoch: 3 total_correct: 51535 loss: 23399.070486426353
epoch: 4 total_correct: 52146 loss: 21560.758033394814
epoch: 0 total_correct: 32649 loss: 82395.9949016571
epoch: 1 total_correct: 42713 loss: 46788.60059380531
epoch: 2 total_correct: 44248 loss: 41884.788912534714
epoch: 3 total_correct: 45226 loss: 39024.42030310631
epoch: 4 total_correct: 45810 loss: 37152.90271937847
epoch: 0 total_correct: 11309 loss: 136967.87161827087
epoch: 1 total_correct: 23829 loss: 126858.8387966156
epoch: 2 total_correct: 30253 loss: 101814.82248306274
epoch: 3 total_correc