In [None]:
import torch
from inception_model import InceptionSham
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import tqdm

In [None]:
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    """Yield aligned ``(inputs, targets)`` mini-batches of exactly ``batchsize`` items.

    Args:
        inputs: Sequence of samples. With ``shuffle=True`` it is indexed with an
            integer array, so it presumably needs NumPy-style fancy indexing
            (plain lists only work with ``shuffle=False``).
        targets: Sequence of labels; must have the same length as ``inputs``.
        batchsize: Number of items per yielded batch.
        shuffle: If true, batches are drawn from a fresh random permutation of
            the indices (``np.random.permutation``) instead of in order.

    Yields:
        Tuples ``(inputs[batch], targets[batch])``.

    Raises:
        AssertionError: If ``inputs`` and ``targets`` differ in length.

    Note:
        Only full batches are produced: trailing items that do not fill a
        complete batch are dropped, and nothing is yielded when
        ``batchsize > len(inputs)``.
    """
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.random.permutation(len(inputs))
    # Module-qualified call: the notebook only does `import tqdm`, so a bare
    # `trange` is not in scope.
    for start_idx in tqdm.trange(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], targets[excerpt]
        
def weights_init(m):
    """Re-initialise one module in place; intended for ``model.apply(weights_init)``.

    The decision is made purely from the module's class name:

    * name contains ``Conv`` but not ``Basic``: weight ~ N(0, 0.02)
    * name contains ``BatchNorm``: weight ~ N(1, 0.02), bias set to 0
    * anything else: left untouched

    Modules that match by name but do not own the corresponding parameter
    (container/wrapper modules, batch norm built with ``affine=False``, ...)
    are skipped instead of raising ``AttributeError``.

    Args:
        m: The module to initialise (any object with a class name; parameters
            are looked up as the ``weight`` / ``bias`` attributes).
    """
    classname = m.__class__.__name__
    # Missing attributes and parameters explicitly set to None are both
    # treated as "nothing to initialise".
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname and 'Basic' not in classname:
        if weight is not None:
            weight.data.normal_(0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1, 0.02)
        if bias is not None:
            bias.data.fill_(0)

In [None]:
# Mini-batch size handed to both DataLoaders below.
BATCH_SIZE = 30
# Intended number of training epochs (full passes over the training set).
N_EPOCH = 1

In [None]:
# Build the train and test loaders from one shared recipe; only the `train`
# flag differs. Both splits are read from "../data" with download=False (no
# download is attempted), converted with ToTensor, batched by BATCH_SIZE and
# shuffled. The generator is consumed in order, so the train loader is still
# constructed before the test loader.
train_loader, test_loader = (
    torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=use_train_split,
            download=False,
            transform=transforms.ToTensor(),
        ),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    for use_train_split in (True, False)
)

In [None]:
# Instantiate the network (presumably 10 output classes and 1 input channel,
# going by the keyword names) and move it to the GPU when CUDA is available.
model = InceptionSham(num_classes=10, input_nc=1, dropout=0.5)
if torch.cuda.is_available():
    model.cuda()

# Overwrite the default parameter initialisation module by module.
model.apply(weights_init)

# Adam with its default hyper-parameters over all model parameters. It holds
# references to the parameters, so it is unaffected by the in-place
# initialisation above.
opt = torch.optim.Adam(model.parameters())

print('inited')

In [None]:
def train(model, n_epoch):
    """Train ``model`` in place for ``n_epoch`` passes over ``train_loader``.

    Relies on two notebook globals: ``train_loader`` (iterated for
    ``(data, target)`` batches) and ``opt`` (the optimizer, stepped once per
    batch).

    Args:
        model: Callable module mapping a data batch to per-class scores; it is
            switched to training mode before the first batch.
        n_epoch: Number of full passes over ``train_loader``.

    Returns:
        None. The model parameters are updated as a side effect.
    """
    model.train()
    # Device availability does not change between batches; query it once.
    use_cuda = torch.cuda.is_available()

    for epoch in range(n_epoch):
        # Wrap the loader itself (not `enumerate(loader)`) so tqdm can read
        # its length and display a total / ETA.
        batches = tqdm.tqdm(train_loader,
                            desc='epoch {}/{}'.format(epoch + 1, n_epoch))
        for data, target in batches:
            if use_cuda:
                data = data.cuda()
                target = target.cuda()
            # Variable wrapping is kept for compatibility with older PyTorch.
            data = Variable(data)
            target = Variable(target)

            opt.zero_grad()
            output = model(data)
            # The model output is treated as unnormalised scores: log-softmax
            # over dim 1 followed by negative log-likelihood.
            loss = F.nll_loss(F.log_softmax(output, dim=1), target)
            loss.backward()
            opt.step()

In [None]:
def test(model):
    model.eval()
    
    accuracy = []
    for i, (data, target) in enumerate(test_loader):
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data).data.max(1)
        accuracy.append(torch.sum(model(data).data.max(1)[1] == target.data) / BATCH_SIZE)
        if i == 15:
            break
    return np.mean(accuracy)

In [None]:
# Take the epoch count from the N_EPOCH configuration constant rather than a
# separate literal, so changing the constant actually changes the run.
train(model, n_epoch=N_EPOCH)

In [None]:
# Evaluate the trained model; as the cell's last expression, the returned
# mean accuracy is displayed as the cell output.
test(model)