In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from datasets import limits

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        #print(data.size())
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [3]:
use_cuda = torch.cuda.is_available()
seed = 1
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
log_interval = 50

In [4]:
train_loader = torch.utils.data.DataLoader(
    limits.LimitDataset(datasets.MNIST('./data/mnist', train=True, download=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), 1000),batch_size=128, shuffle=True, **kwargs)

In [5]:
test_loader = torch.utils.data.DataLoader(
    limits.LimitDataset(datasets.MNIST('./data/mnist', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), 500),
    batch_size=128, shuffle=True, **kwargs)

In [6]:
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [7]:
epochs = 200
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)


Test set: Average loss: 2.2937, Accuracy: 69/500 (14%)


Test set: Average loss: 2.2802, Accuracy: 71/500 (14%)


Test set: Average loss: 2.2691, Accuracy: 77/500 (15%)


Test set: Average loss: 2.2588, Accuracy: 87/500 (17%)


Test set: Average loss: 2.2449, Accuracy: 114/500 (23%)


Test set: Average loss: 2.2253, Accuracy: 141/500 (28%)


Test set: Average loss: 2.2019, Accuracy: 160/500 (32%)


Test set: Average loss: 2.1719, Accuracy: 178/500 (36%)


Test set: Average loss: 2.1309, Accuracy: 191/500 (38%)


Test set: Average loss: 2.0755, Accuracy: 218/500 (44%)


Test set: Average loss: 2.0110, Accuracy: 224/500 (45%)


Test set: Average loss: 1.9226, Accuracy: 243/500 (49%)


Test set: Average loss: 1.8287, Accuracy: 252/500 (50%)


Test set: Average loss: 1.7224, Accuracy: 268/500 (54%)


Test set: Average loss: 1.6130, Accuracy: 290/500 (58%)


Test set: Average loss: 1.4944, Accuracy: 307/500 (61%)


Test set: Average loss: 1.4061, Accuracy: 323/500 (65%)


Test set: Average


Test set: Average loss: 0.3361, Accuracy: 445/500 (89%)


Test set: Average loss: 0.3279, Accuracy: 446/500 (89%)


Test set: Average loss: 0.3356, Accuracy: 448/500 (90%)


Test set: Average loss: 0.3311, Accuracy: 453/500 (91%)


Test set: Average loss: 0.3377, Accuracy: 447/500 (89%)


Test set: Average loss: 0.3282, Accuracy: 445/500 (89%)


Test set: Average loss: 0.3226, Accuracy: 449/500 (90%)


Test set: Average loss: 0.3182, Accuracy: 450/500 (90%)


Test set: Average loss: 0.3149, Accuracy: 451/500 (90%)


Test set: Average loss: 0.3136, Accuracy: 451/500 (90%)


Test set: Average loss: 0.3181, Accuracy: 449/500 (90%)


Test set: Average loss: 0.3186, Accuracy: 450/500 (90%)


Test set: Average loss: 0.3206, Accuracy: 450/500 (90%)


Test set: Average loss: 0.3104, Accuracy: 452/500 (90%)


Test set: Average loss: 0.3159, Accuracy: 451/500 (90%)


Test set: Average loss: 0.2990, Accuracy: 453/500 (91%)


Test set: Average loss: 0.2957, Accuracy: 452/500 (90%)


Test set: Ave


Test set: Average loss: 0.2333, Accuracy: 463/500 (93%)


Test set: Average loss: 0.2277, Accuracy: 465/500 (93%)


Test set: Average loss: 0.2229, Accuracy: 464/500 (93%)


Test set: Average loss: 0.2216, Accuracy: 464/500 (93%)


Test set: Average loss: 0.2292, Accuracy: 465/500 (93%)


Test set: Average loss: 0.2265, Accuracy: 464/500 (93%)


Test set: Average loss: 0.2319, Accuracy: 464/500 (93%)


Test set: Average loss: 0.2222, Accuracy: 462/500 (92%)


Test set: Average loss: 0.2278, Accuracy: 466/500 (93%)


Test set: Average loss: 0.2327, Accuracy: 465/500 (93%)


Test set: Average loss: 0.2249, Accuracy: 465/500 (93%)


Test set: Average loss: 0.2287, Accuracy: 465/500 (93%)


Test set: Average loss: 0.2292, Accuracy: 464/500 (93%)


Test set: Average loss: 0.2305, Accuracy: 463/500 (93%)


Test set: Average loss: 0.2238, Accuracy: 466/500 (93%)


Test set: Average loss: 0.2223, Accuracy: 465/500 (93%)


Test set: Average loss: 0.2172, Accuracy: 465/500 (93%)


Test set: Ave