In [None]:
from __future__ import print_function

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms

from mnist_models import return_model
# Highest test accuracy (in percent) seen so far; test() compares against it and
# overwrites it on improvement. Despite the name, it never holds a loss value.
best_loss = 0



## Dataset and Training

In [None]:
def train(args, model, device, train_loader, optimizer, epoch, name):
    """Run one training epoch and snapshot the weights after every batch.

    Args:
        args: Settings mapping. Accepted for interface compatibility; not read here.
        model: Module to optimise. Its raw output is passed through
            ``log_softmax`` over dim 1 before the NLL loss is computed.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        epoch: Epoch number; used only in the snapshot file name.
        name: Model tag; used only in the snapshot file name.

    Side effects:
        After every optimizer step, writes the model ``state_dict`` to
        ``./mnist_cnn/mnist_{name}_{batch}_{epoch}.pt`` (``batch`` is 1-based).
        The ``./mnist_cnn`` directory is created if it does not exist.
    """
    # torch.save does not create missing parent directories, so make sure the
    # snapshot folder exists once up front instead of failing on the first batch.
    os.makedirs("./mnist_cnn", exist_ok=True)

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        output = F.log_softmax(output, dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # NOTE(review): one file is written per batch, which looks deliberate given
        # the batch index in the name -- confirm this volume of snapshots is wanted.
        torch.save(model.state_dict(), "./mnist_cnn/mnist_{}_{}_{}.pt".format(name, batch_idx+1, epoch))

def test(args, model, device, test_loader, epoch, name):
    """Evaluate ``model`` on ``test_loader`` and checkpoint it on a new best accuracy.

    Args:
        args: Settings mapping. Accepted for interface compatibility; not read here.
        model: Module to evaluate. Its raw output is passed through
            ``log_softmax`` over dim 1.
        device: Device each batch is moved to before the forward pass.
        test_loader: Iterable yielding ``(data, target)`` batches. It must expose
            ``.dataset`` with a length, which is used as the sample count.
        epoch: Epoch number; used in the checkpoint file name and the log line.
        name: Model tag; used in the checkpoint file name.

    Side effects:
        Reads and updates the module-level ``best_loss``. Despite its name, that
        variable stores the best accuracy (in percent) seen so far. When the
        current accuracy exceeds it, the ``state_dict`` is written to
        ``./mnist_cnn/mnist_{name}_10_{epoch}.pt`` and the metrics are printed.
        Nothing is printed or saved when accuracy does not improve.
    """
    global best_loss
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            output = F.log_softmax(output, dim=1)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    accuracy = 100. * correct / num_samples

    if accuracy > best_loss:
        print("Saving Model...")
        best_loss = accuracy
        # torch.save does not create missing parent directories; make sure the
        # folder exists so a fresh checkout does not fail on the first save.
        os.makedirs("./mnist_cnn", exist_ok=True)
        # NOTE(review): this path is the same one train() uses for its batch-10
        # snapshot of the same epoch and name, so one overwrites the other --
        # confirm that is intended.
        torch.save(model.state_dict(), "./mnist_cnn/mnist_{}_10_{}.pt".format(name, epoch))
        print('\nTest set {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(epoch, test_loss, correct, num_samples,
                    accuracy))

# Training settings
# Each option is a plain module-level name and is also collected into `args` below.
batch_size, test_batch_size = 64, 1000  # samples per batch: training / evaluation
epochs = 14
lr, gamma = 0.1, 0.7                    # initial learning rate / per-epoch StepLR decay
sparseness = 0.0
no_cuda = False                         # set True to stay on the CPU even if CUDA exists
seed = 1
log_interval = 100
save_model = True
classes = 10                            # labels 0..classes-1 are kept when filtering the data

# The same settings as one mapping, keyed by hyphenated option names.
args = {
    'epochs': epochs,
    'batch-size': batch_size,
    'test-batch-size': test_batch_size,
    'lr': lr,
    'sparseness': sparseness,
    'gamma': gamma,
    'no-cuda': no_cuda,
    'seed': seed,
    'log-interval': log_interval,
    'save-model': save_model,
}

# CUDA availability is only queried when the user has not opted out.
if args['no-cuda']:
    use_cuda = False
else:
    use_cuda = torch.cuda.is_available()

torch.manual_seed(args['seed'])

if use_cuda:
    device = torch.device("cuda")
    # Extra DataLoader options that are only applied on the CUDA path.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

def _class_mask(dataset, num_classes):
    """Return a boolean tensor marking samples whose label is in ``range(num_classes)``.

    ``dataset.targets`` is wrapped with ``torch.as_tensor``, which reuses an
    existing tensor instead of copy-constructing it. Copy-constructing via
    ``torch.tensor(tensor)`` emits a UserWarning and duplicates the labels on
    every call. A single range comparison replaces one equality pass per class.
    Assumes integer labels and ``num_classes >= 1``.
    """
    targets = torch.as_tensor(dataset.targets)
    return (targets >= 0) & (targets < num_classes)


# Both splits share the same preprocessing: tensor conversion followed by
# normalisation with the constants (0.1307,) and (0.3081,).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

data_training = datasets.MNIST('../data', train=True, download=True,
                               transform=mnist_transform)

data_testing = datasets.MNIST('../data', train=False, transform=mnist_transform)

# Keep only the samples labelled 0..classes-1 in each split.
idx = _class_mask(data_training, classes)
data_training = torch.utils.data.dataset.Subset(data_training, np.flatnonzero(idx.numpy()))

idx = _class_mask(data_testing, classes)
data_testing = torch.utils.data.dataset.Subset(data_testing, np.flatnonzero(idx.numpy()))

# Batched iterators over the two splits. Both are created with shuffle=True;
# `kwargs` carries the extra worker / pinned-memory options chosen earlier.
train_loader = torch.utils.data.DataLoader(
    dataset=data_training,
    batch_size=args['batch-size'],
    shuffle=True,
    **kwargs
)
test_loader = torch.utils.data.DataLoader(
    dataset=data_testing,
    batch_size=args['test-batch-size'],
    shuffle=True,
    **kwargs
)



## Select Model and Train

In [None]:
# Build the network identified by `name` and place it on the selected device.
name = 'full'
model = return_model(name).to(device)

# Adam receives only the parameters of model[-1]; the learning rate is then
# multiplied by `gamma` after every epoch via StepLR.
# NOTE(review): the remaining sub-modules are therefore never updated by this
# optimizer -- confirm that training only the last one is intended.
optimizer = optim.Adam(params=model[-1].parameters(), lr=args['lr'])
scheduler = StepLR(optimizer, step_size=1, gamma=args['gamma'])

# Reset the best-accuracy tracker that test() reads and updates, so that
# re-running this cell starts from a clean slate.
best_loss = 0

# One pass per epoch: train, evaluate (which may checkpoint), then decay the LR.
for epoch in range(1, epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch, name)
    test(args, model, device, test_loader, epoch, name)
    scheduler.step()