In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

# Seed torch's global RNG so weight initialization and DataLoader shuffling are
# repeatable on Restart & Run All (otherwise every prediction shown below changes).
SEED = 42
torch.manual_seed(SEED)

# Wide lines so a row of 10 class scores prints without wrapping.
torch.set_printoptions(linewidth=120)

In [54]:
# Record library versions so results can be tied to the environment that produced them.
print(torch.__version__, torchvision.__version__, sep='\n')

1.6.0
0.7.0


In [55]:
# FashionMNIST training split, stored under ./data (downloaded if missing).
# ToTensor converts each image to a float tensor scaled to [0.0, 1.0].
train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [56]:
class Network(nn.Module):
    """Small CNN mapping single-channel images to 10 class scores.

    Two conv(5x5) -> ReLU -> max-pool(2x2, stride 2) stages are followed by
    three fully connected layers. For a 28x28 input the spatial size goes
    28 -conv1-> 24 -pool-> 12 -conv2-> 8 -pool-> 4, which is why fc1 expects
    12 * 4 * 4 flattened features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        # 12 channels x 4 x 4 spatial positions after the second pool (28x28 input).
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Return raw class scores (logits) of shape (N, 10); no softmax is applied.

        Args:
            t: batched image tensor of shape (N, 1, H, W) whose conv/pool output
                is 12 x 4 x 4 (true for 28x28 images).

        Raises:
            RuntimeError: if the flattened feature size does not match fc1
                (e.g. wrongly sized images), rather than silently changing N.
        """
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # Flatten every dim except the batch dim. reshape(-1, 192) would let -1
        # absorb a size mismatch and silently produce the wrong number of rows.
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        return self.out(t)

In [57]:
# Disable autograd globally: the cells that follow only run a forward pass, so no
# computation graph is needed. NOTE: this persists across cells for the rest of
# the kernel session -- call torch.set_grad_enabled(True) before any training.
# The trailing ';' suppresses the echoed context-manager repr.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x106725c1100>

In [58]:
# Fresh, untrained instance: weights come from PyTorch's default (random) layer
# initialization, so the predictions below are essentially arbitrary.
network = Network()

In [59]:
# Serves the training set in batches of 10 images, reshuffled every epoch.
data_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=10,
    shuffle=True,
)

In [60]:
# Pull a single batch from a fresh iterator; with shuffle=True this is a random
# draw of batch_size images from the training set.
batch = next(iter(data_loader))

In [61]:
# Each batch is an (images, labels) pair of tensors.
images, labels = batch

In [62]:
images.size()  # (batch, channels, height, width)

torch.Size([10, 1, 28, 28])

In [63]:
labels.size()  # one class index per image in the batch

torch.Size([10])

In [64]:
# Forward pass through the untrained network (autograd is globally disabled earlier
# in the notebook). Calling the module, rather than .forward, keeps any hooks active.
preds = network(images)

In [65]:
preds.size()  # (batch, num_classes)

torch.Size([10, 10])

In [66]:
preds  # raw logits: one row per image, one column per class (no softmax applied)

tensor([[-0.0794,  0.0506,  0.0625,  0.0880, -0.0026,  0.0605, -0.1002, -0.0415,  0.1046, -0.0426],
        [-0.0735,  0.0567,  0.0643,  0.0848, -0.0127,  0.0578, -0.1037, -0.0428,  0.1046, -0.0458],
        [-0.0713,  0.0530,  0.0676,  0.0891, -0.0083,  0.0604, -0.1038, -0.0403,  0.1058, -0.0414],
        [-0.0760,  0.0571,  0.0695,  0.0904, -0.0104,  0.0577, -0.1044, -0.0420,  0.1093, -0.0425],
        [-0.0767,  0.0501,  0.0646,  0.0871, -0.0080,  0.0577, -0.1009, -0.0451,  0.1048, -0.0432],
        [-0.0739,  0.0564,  0.0633,  0.0846, -0.0131,  0.0603, -0.0993, -0.0431,  0.1023, -0.0470],
        [-0.0715,  0.0504,  0.0712,  0.0881, -0.0067,  0.0569, -0.0995, -0.0473,  0.1027, -0.0445],
        [-0.0730,  0.0507,  0.0658,  0.0800, -0.0050,  0.0563, -0.0997, -0.0452,  0.1069, -0.0470],
        [-0.0768,  0.0566,  0.0654,  0.0871, -0.0133,  0.0614, -0.1014, -0.0420,  0.1045, -0.0456],
        [-0.0743,  0.0530,  0.0681,  0.0878, -0.0073,  0.0568, -0.1020, -0.0465,  0.1064, -0.0461]])

`argmax(dim=1)` returns, for each sample (row) in the batch, the index of its highest prediction value — that index is the predicted class.

In [67]:
torch.argmax(preds, dim=1)  # each entry is the index of the highest score in that row, i.e. the predicted class per sample

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8])

In [68]:
labels  # true class indices for the images in this batch

tensor([5, 8, 1, 3, 7, 4, 7, 5, 6, 7])

In [69]:
preds.argmax(dim=1) == labels  # element-wise: True where the predicted class matches the label

tensor([False,  True, False, False, False, False, False, False, False, False])

In [70]:
 (preds.argmax(dim=1) == labels).sum()  # count of correct predictions, as a 0-dim tensor

tensor(1)

In [72]:
def get_num_correct(preds, labels):
    """Count how many samples in a batch were classified correctly.

    Args:
        preds: tensor of per-class scores, assumed (batch, num_classes).
        labels: tensor of true class indices, one per sample.

    Returns:
        Number of samples whose highest-scoring class equals the label, as a Python int.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes.eq(labels)
    return matches.sum().item()

In [73]:
get_num_correct(preds, labels)  # correct predictions in this batch of 10 (untrained network)

1