In [17]:
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms
import os


# CNN class

class CNN(nn.Module):
    """Small LeNet-style convolutional classifier for single-channel images.

    Layer stack (sizes shown for a 1x28x28 input such as MNIST):

        conv1 (1->10, 5x5)  -> max-pool(2) -> ReLU      : 10x12x12
        conv2 (10->20, 5x5) -> channel dropout
                            -> max-pool(2) -> ReLU      : 20x4x4
        flatten                                         : 320
        fc1 (320->50) -> ReLU -> dropout                : 50
        fc2 (50->10)  -> log-softmax over classes       : 10

    ``fc1`` is hard-wired to 320 inputs, so only inputs whose spatial size
    reduces to 20x4x4 (i.e. 28x28 images) are accepted.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)  # 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Zeroes whole feature maps during training (regularization).
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)  # affine layer, 20*4*4 -> 50
        self.fc2 = nn.Linear(50, 10)  # affine layer, 50 -> 10 class scores

    def forward(self, x):
        """Compute per-class log-probabilities for a batch of images.

        Args:
            x: image tensor, normally shaped (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) with log-softmax scores per class.
            Dropout layers are active only while ``self.training`` is True.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Collapse (channels, height, width) into one feature axis while keeping
        # the batch axis: (batch, 20, 4, 4) -> (batch, 320). Unlike
        # view(-1, 320), a wrong input size can never regroup features across
        # samples; fc1 raises a shape error instead.
        x = torch.flatten(x, start_dim=-3)
        if x.dim() == 1:  # unbatched (C, H, W) input: keep a batch axis of 1
            x = x.unsqueeze(0)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def name(self):
        """Return a short human-readable model name."""
        return "CNN"

root = './data'
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: there is no race
# between the check and the creation, and missing parent dirs are created too.
os.makedirs(root, exist_ok=True)
batch_size = 32  # powers of 2 are a common (not required) batch-size choice

# Transformations applied to every image: PIL image -> float tensor.
trans = transforms.Compose([transforms.ToTensor()])

# MNIST training and test splits, downloaded into `root` if not already there.
train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)

# Training loader: reshuffled each epoch so mini-batches vary between epochs.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
# Test loader: evaluation results do not depend on sample order, so a fixed
# order keeps runs reproducible.
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
)


cnn = CNN()  # our (untrained) model
# Use the builtin next(): DataLoader iterators implement __next__, and the old
# `.next()` alias is gone in newer PyTorch releases (AttributeError).
x, target = next(iter(train_loader))
# Inference-only forward pass: no_grad skips autograd bookkeeping, replacing
# the deprecated `.data` access that was used to strip the graph.
# NOTE: the model is still in training mode, so dropout is active and these
# scores vary between runs; call cnn.eval() first for deterministic output.
with torch.no_grad():
    output = cnn(x)  # (batch, 10) log-probabilities
print(f"Raw output: {output}")
# Column index of the largest log-probability in each row = predicted class.
val = output.argmax(dim=1)
print(f"Torch max : {val}")



Raw output: tensor([[-2.2852, -2.3701, -2.3718, -2.3888, -2.4036, -2.3496, -2.2624,
         -2.1046, -2.3356, -2.1967],
        [-2.2538, -2.3363, -2.3198, -2.3856, -2.4149, -2.2131, -2.3117,
         -2.1386, -2.4258, -2.2640],
        [-2.3204, -2.4190, -2.2878, -2.3541, -2.3447, -2.3508, -2.2095,
         -2.0742, -2.4645, -2.2569],
        [-2.2460, -2.3709, -2.2903, -2.4660, -2.4312, -2.2263, -2.2917,
         -2.1991, -2.3573, -2.1885],
        [-2.1918, -2.3290, -2.4032, -2.3643, -2.4201, -2.3337, -2.3775,
         -2.0779, -2.3577, -2.2256],
        [-2.2018, -2.3665, -2.3591, -2.4077, -2.3594, -2.2703, -2.3117,
         -2.1353, -2.3746, -2.2734],
        [-2.2595, -2.2957, -2.3122, -2.3859, -2.4773, -2.2062, -2.3077,
         -2.1928, -2.3644, -2.2568],
        [-2.2544, -2.3669, -2.3523, -2.3316, -2.3828, -2.3558, -2.2710,
         -2.0695, -2.4696, -2.2273],
        [-2.2654, -2.3750, -2.3281, -2.3811, -2.4091, -2.3108, -2.2799,
         -2.1433, -2.4046, -2.1686],
       