In [None]:
import torch
import torchvision

tt = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                     torchvision.transforms.Normalize((0.1307,), (0.3081,))])

mnist_train = torchvision.datasets.MNIST('./files/', train=True, download=True, transform=tt)
mnist_test = torchvision.datasets.MNIST('./files/', train=False, download=True, transform=tt)


train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=20, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size = 1000, shuffle=True)

examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [2]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7f02d22bfd30>

In [3]:
example_data.shape

torch.Size([20, 1, 28, 28])

In [4]:
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.optim as optim

In [5]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim = 1)

In [6]:
network = Net()

In [7]:
optimizer = optim.SGD(network.parameters(), lr = 0.01)

In [None]:
for iteration in range(10):
  for batch_idx, (data, target) in enumerate(train_loader):

    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()

    if batch_idx % 500 == 0:
      print(loss.item())

2.3389291763305664
1.1801916360855103
0.3298601508140564
0.6380429863929749
0.32825911045074463
0.40180182456970215
0.38519176840782166
0.21867668628692627
0.21635110676288605
0.13684645295143127
0.31843820214271545
0.25247684121131897
0.13615700602531433
0.2451208084821701
0.3534315228462219
0.2231675684452057
0.30634862184524536
0.28645870089530945
0.04708176106214523
0.19946470856666565
0.22912177443504333
0.24374370276927948
0.4475262761116028
0.2408103495836258
0.7664046287536621
0.09364001452922821
0.3772566020488739
0.0638391450047493
0.32588157057762146
0.024475598707795143
0.08929113298654556
0.4940183758735657
0.2506020665168762
0.017952853813767433
0.01199860405176878
0.2861446142196655
0.4728274345397949
0.2366124838590622
1.0166678428649902
0.049511827528476715
0.06639349460601807
0.2823609709739685
0.060056041926145554
0.6406525373458862
