In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=120)
torch.set_grad_enabled(True)

from torch.utils.tensorboard import SummaryWriter

from itertools import product


In [2]:
print (torch.__version__)
print(torchvision.__version__)

1.6.0+cu101
0.7.0+cu101


In [3]:
def get_num_correct(preds,labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [4]:

class Network(nn.Module):   
    def __init__(self):
        super().__init__()
        self.conv1=nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5)
        self.conv2=nn.Conv2d(in_channels=6,out_channels=12,kernel_size=5)
        
        self.fc1=nn.Linear(in_features=12*4*4,out_features=120)
        self.fc2=nn.Linear(in_features=120,out_features=60)
        self.out=nn.Linear(in_features=60,out_features=10)
        
    def forward(self,t):
        t=F.relu(self.conv1(t))
        t=F.max_pool2d(t,kernel_size=2,stride=2)
        
        t=F.relu(self.conv2(t))
        t=F.max_pool2d(t,kernel_size=2,stride=2)
        
        t=t.flatten(start_dim=1)
        t=F.relu(self.fc1(t))
        
        t=F.relu(self.fc2(t))
        
        t=self.out(t)
        return t

In [5]:
train_set=torchvision.datasets.FashionMNIST(
         root='./data'
         ,train=True
         ,download=True
         ,transform=transforms.Compose([
             transforms.ToTensor()]))

In [6]:
train_loader=torch.utils.data.DataLoader(train_set,batch_size=100,shuffle=True)


In [7]:
tb=SummaryWriter()
network=Network()
images,lables=next(iter(train_loader))
grid=torchvision.utils.make_grid(images)
tb.add_image('images',grid)
tb.add_graph(network,images)
tb.close()

In [13]:
parameters=dict(
            lr=[0.01,0.001]
            ,batch_size=[10,100,1000]
            ,shuffle=[True,False])
param_values=[v for v in parameters.values()]
param_values

[[0.01, 0.001], [10, 100, 1000], [True, False]]

In [15]:
batch_size=100
lr=0.01

network=Network()

train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)
images,lables=next(iter(train_loader))
grid=torchvision.utils.make_grid(images)
for lr,batch_size,shuffle in product(*param_values):
    
    optimizer=optim.Adam(network.parameters(),lr=lr)    
    comment=f'batch_size= {batch_size} lr={lr}'
    tb=SummaryWriter(comment=comment)
    tb.add_image('images',grid)
    tb.add_graph(network,images)

    for epoch in range(10):
        total_loss=0
        total_correct=0
        for batch in train_loader:
            images,labels=batch

            preds=network(images)
            loss=F.cross_entropy(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss+=loss.item()* batch_size
            total_correct+=get_num_correct(preds,labels)

        tb.add_scalar("Loss ",total_loss,epoch)
        tb.add_scalar("Number Correct ",total_correct,epoch)
        tb.add_scalar("Accuracy ",total_correct/len(train_set),epoch)

        #tb.add_histogram("conv1.bias ",network.conv1.bias,epoch)
        #tb.add_histogram("conv1.weight", network.conv1.weight,epoch)
        #tb.add_histogram("conv1.weight.grad", network.conv1.weight.grad,epoch)

        for name,weight in network.named_parameters():
            tb.add_histogram(name,weight,epoch)
            tb.add_histogram(f'{name}.grad',weight.grad,epoch)

        print("epoch ",epoch, " totalcorrect: ",total_correct, "total_loss: ",total_loss)
tb.close()

epoch  0  totalcorrect:  47610 total_loss:  3324.562201499939
epoch  1  totalcorrect:  51578 total_loss:  2261.6957055032253
epoch  2  totalcorrect:  52211 total_loss:  2073.266861587763
epoch  3  totalcorrect:  52624 total_loss:  1982.511407583952
epoch  4  totalcorrect:  52793 total_loss:  1930.6244206428528
epoch  5  totalcorrect:  52943 total_loss:  1861.9544284045696
epoch  6  totalcorrect:  53085 total_loss:  1822.0839728415012
epoch  7  totalcorrect:  53305 total_loss:  1778.5763147473335
epoch  8  totalcorrect:  53230 total_loss:  1798.454464301467
epoch  9  totalcorrect:  53440 total_loss:  1750.819897428155
epoch  0  totalcorrect:  53338 total_loss:  1799.2540091276169
epoch  1  totalcorrect:  53609 total_loss:  1716.7603392153978
epoch  2  totalcorrect:  53669 total_loss:  1698.791280835867
epoch  3  totalcorrect:  53712 total_loss:  1686.9626589119434
epoch  4  totalcorrect:  53741 total_loss:  1694.0613471716642
epoch  5  totalcorrect:  53906 total_loss:  1637.085241153836