In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

# Allow up to 120 characters per line when tensors are printed.
torch.set_printoptions(linewidth=120)

In [2]:
# Print the installed PyTorch version.
print(torch.__version__)

1.3.0


In [3]:
def get_num_correct(preds,labels):
    """Count the rows of ``preds`` whose highest-scoring column equals the label.

    Args:
        preds: tensor of scores; the argmax is taken along dim 1.
        labels: tensor of class indices compared element-wise with that argmax.

    Returns:
        The number of matching positions as a plain Python number.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes.eq(labels)
    return matches.sum().item()

In [4]:
class Network(nn.Module):
    """Small CNN classifier: two conv + max-pool stages followed by three linear layers.

    ``fc1`` expects 12*4*4 features per sample. That is what the conv/pool stack
    produces for a 1-channel 28x28 input (spatial side 28 -> 24 -> 12 -> 8 -> 4).
    Inputs of another spatial size raise an error at ``fc1``.
    """
    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 6 -> 12 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=12,kernel_size=5)

        # Classifier head: 12*4*4 -> 120 -> 60 -> 10 output scores.
        self.fc1 = nn.Linear(in_features=12*4*4,out_features=120)
        self.fc2 = nn.Linear(in_features=120,out_features=60)
        self.out = nn.Linear(in_features=60,out_features=10)

    def forward(self,t):
        """Run a batch through the network.

        Args:
            t: input tensor laid out as (batch, 1, height, width).

        Returns:
            Tensor of shape (batch, 10) holding the raw outputs of the final
            linear layer (no softmax is applied here).
        """
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t,kernel_size=2,stride=2)

        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t,kernel_size=2,stride=2)

        # Flatten everything except the batch axis. Unlike reshape(-1, 12*4*4),
        # this cannot fold surplus features into extra batch rows, so an input
        # of the wrong spatial size fails loudly at fc1 instead of silently
        # returning the wrong number of predictions.
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))

        t = F.relu(self.fc2(t))

        t = self.out(t)

        return t
    

In [6]:
# FashionMNIST training split rooted at ./data; download=True fetches the
# files into that directory. ToTensor is the only transform applied per image.
tran_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

0it [00:00, ?it/s]

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|█████████▉| 26419200/26421880 [01:22<00:00, 380768.04it/s] 

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw



0it [00:00, ?it/s][A

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0/29515 [00:00<?, ?it/s][A
 56%|█████▌    | 16384/29515 [00:00<00:00, 84982.28it/s][A
 83%|████████▎ | 24576/29515 [00:00<00:00, 65391.06it/s][A
32768it [00:01, 29153.15it/s]                           [A

0it [00:00, ?it/s][A

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0/4422102 [00:00<?, ?it/s][A
  0%|          | 16384/4422102 [00:00<00:44, 98421.80it/s][A
  1%|          | 32768/4422102 [00:00<00:44, 98264.40it/s][A
  1%|          | 40960/4422102 [00:00<00:59, 74221.27it/s][A
  1%|▏         | 57344/4422102 [00:01<00:54, 80266.68it/s][A
  2%|▏         | 90112/4422102 [00:01<00:44, 97489.16it/s][A
  4%|▎         | 155648/4422102 [00:01<00:33, 125819.84it/s][A
  6%|▌         | 262144/4422102 [00:01<00:25, 166231.02it/s][A
  8%|▊         | 352256/4422102 [00:01<00:19, 209069.03it/s][A
 11%|█         | 475136/4422102 [00:01<00:14, 267710.66it/s][A
 13%|█▎        | 581632/4422102 [00:02<00:11, 324405.25it/s][A
 17%|█▋        | 745472/4422102 [00:02<00:09, 405025.69it/s][A
 22%|██▏       | 983040/4422102 [00:02<00:06, 534671.95it/s][A
 26%|██▌       | 1155072/4422102 [00:02<00:04, 673872.78it/s][A
 29%|██▉       | 1286144/4422102 [00:02<00:04, 750453.60it/s][A
 34%|███▍      | 1507328/4422102 [00:02<00:03, 922373.73it/s][A

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz



8192it [00:00, 22092.23it/s]            [A

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Processing...
Done!





In [10]:
# Fresh model, data loader and optimizer for this training run.
network = Network()

# shuffle=True draws the samples in a new random order every epoch; with the
# DataLoader default (shuffle=False) each epoch would replay the exact same
# sequence of batches.
train_loader = torch.utils.data.DataLoader(tran_set,batch_size=100,shuffle=True)
optimizer = optim.Adam(network.parameters(),lr=0.01)

for epoch in range(5):
    # Per-epoch accumulators: sum of the per-batch loss values and the number
    # of correctly classified samples.
    total_loss = 0
    total_correct = 0
    for batch in train_loader:
        images,labels = batch

        preds = network(images)  # forward pass on the batch
        loss = F.cross_entropy(preds,labels)  # calculate loss

        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()  # calculate gradients
        optimizer.step()  # update weights

        total_loss += loss.item()
        total_correct += get_num_correct(preds,labels)
    print("epoch:",epoch,"total correct:",total_correct,"loss:",total_loss)

epoch: 0 total correct: 46502 loss: 358.44425243139267
epoch: 1 total correct: 51063 loss: 241.92368979752064
epoch: 2 total correct: 51704 loss: 221.7484328597784
epoch: 3 total correct: 52103 loss: 214.53608672320843
epoch: 4 total correct: 52046 loss: 210.90614467859268


In [11]:
# Fraction of training samples classified correctly: total_correct still holds
# the count accumulated during the last epoch of the training loop above.
total_correct/len(tran_set)

0.8674333333333333