In [0]:
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets , transforms
import torch.nn.functional as F
import torch.nn as nn


import torch.optim as optim

In [2]:
# Let printed tensors use up to 120 characters per line before wrapping.
torch.set_printoptions(linewidth=120)
# Explicitly switch autograd gradient tracking on. Keep this call last: as the
# cell's final expression, its return value is what the notebook echoes below.
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2702d45c50>

In [0]:
# Record the library versions this notebook ran against, one per line
# (torch first, then torchvision).
print(torch.__version__, torchvision.__version__, sep='\n')

1.1.0
0.3.0


In [3]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU. The model and every batch are moved here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [0]:
def get_num_correct(preds, labels):
    """Count the samples whose highest-scoring index equals the given label.

    Args:
        preds: tensor of scores with at least two dimensions; the arg-max is
            taken along dim 1.
        labels: tensor compared element-wise against those arg-max indices.

    Returns:
        The number of positions where prediction and label agree, as a plain
        Python number (via ``.item()``).
    """
    predicted = preds.argmax(dim=1)
    hits = predicted == labels
    return hits.sum().item()

In [5]:
# FashionMNIST training split, stored under ./data and downloaded when asked
# for via download=True. Each sample passes through ToTensor() on access.
train_set = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

  0%|          | 0/26421880 [00:00<?, ?it/s]

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


26427392it [00:00, 41172558.88it/s]                              


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


32768it [00:00, 564009.46it/s]
  0%|          | 16384/4422102 [00:00<00:29, 148366.12it/s]

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


4423680it [00:00, 18675932.42it/s]                           
8192it [00:00, 99084.81it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [0]:
# Serve the training set in shuffled mini-batches of 64 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=64,
    shuffle=True,
)

In [0]:
# TensorBoard event writer used by the training loop for scalars and histograms.
# Created with no arguments, so the log location is SummaryWriter's default.
writer = SummaryWriter()

In [0]:
class Network(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The first convolution takes a single input channel and the last layer
    emits 10 values per sample. ``fc1`` expects 4*4*12 flattened features,
    which is what a 28x28 input is reduced to (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 6 -> 12 channels, both with 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)
        # Classifier head: 192 -> 120 -> 60 -> 10.
        self.fc1 = nn.Linear(4 * 4 * 12, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, x):
        """Return the raw output-layer scores (no softmax applied) for batch ``x``."""
        # Stage 1 and 2: convolution, ReLU, then 2x2 max-pooling with stride 2.
        features = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        features = F.max_pool2d(F.relu(self.conv2(features)), kernel_size=2, stride=2)

        # Collapse everything except the batch dimension before the linear layers.
        hidden = F.relu(self.fc1(torch.flatten(features, start_dim=1)))
        hidden = F.relu(self.fc2(hidden))

        return self.out(hidden)

In [0]:
# Build the model and move it to the selected device, then hand its parameters
# to plain SGD with a learning rate of 0.01.
network = Network()
network = network.to(device)
optimizer = optim.SGD(params=network.parameters(), lr=0.01)

In [11]:
# Train for 20 epochs, logging per-epoch metrics to TensorBoard.
#
# The loop is wrapped in try/finally so the SummaryWriter is closed (and its
# pending events written out) even when training is interrupted or a batch
# raises; previously close() only ran after all 20 epochs completed.
try:
    for epoch in range(20):

        # Running totals for this epoch only.
        total_loss = 0
        total_correct = 0

        for images, labels in train_loader:

            # The batch must live on the same device as the model.
            images = images.to(device)
            labels = labels.to(device)

            preds = network(images)
            loss = F.cross_entropy(preds, labels)

            # Clear the previous batch's gradients before computing new ones.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # total_loss is the sum of the per-batch loss values; total_correct
            # counts hits using predictions made before this batch's update.
            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

        writer.add_scalar('Loss', total_loss, epoch)
        writer.add_scalar('Number Correct', total_correct, epoch)
        writer.add_scalar('Accuracy', total_correct / len(train_set), epoch)

        # Snapshot of the first conv layer at the end of the epoch. The gradient
        # histogram reflects only the last batch, since gradients are zeroed
        # before every backward pass.
        writer.add_histogram('conv1.bias', network.conv1.bias, epoch)
        writer.add_histogram('conv1.weight', network.conv1.weight, epoch)
        writer.add_histogram('conv1.weight.grad', network.conv1.weight.grad, epoch)

        print('epoch: ',epoch ,'  total_correct: ',total_correct , 'total_loss',total_loss )
finally:
    # Always release the event file, whether training finished or not.
    writer.close()

epoch:  0   total_correct:  10111 total_loss 2106.5164833068848
epoch:  1   total_correct:  37289 total_loss 931.2566511631012
epoch:  2   total_correct:  42968 total_loss 696.8557475507259
epoch:  3   total_correct:  45305 total_loss 607.394151777029
epoch:  4   total_correct:  46769 total_loss 552.4441986083984
epoch:  5   total_correct:  47782 total_loss 513.5090279579163
epoch:  6   total_correct:  48563 total_loss 486.10691244900227
epoch:  7   total_correct:  49081 total_loss 464.61606496572495
epoch:  8   total_correct:  49477 total_loss 445.9686411470175
epoch:  9   total_correct:  49824 total_loss 431.4088762551546
epoch:  10   total_correct:  50193 total_loss 417.9538874179125
epoch:  11   total_correct:  50449 total_loss 406.69760762155056
epoch:  12   total_correct:  50676 total_loss 396.4448794722557
epoch:  13   total_correct:  50925 total_loss 385.40708957612514
epoch:  14   total_correct:  51179 total_loss 376.37674854695797
epoch:  15   total_correct:  51386 total_loss