In [1]:
import numpy as np 
import torch 
import torch.nn as nn 
from torchvision import datasets, transforms 
from torch.autograd import Variable 
from torch.nn import init
import torch.nn.functional as F
import matplotlib.pyplot as plt 
import seaborn as sns
%matplotlib inline

In [2]:
class ListModule(object):
    def __init__(self, module, prefix, *args):
        self.module = module
        self.prefix = prefix
        self.num_module = 0
        for new_module in args:
            self.append(new_module)

    def append(self, new_module):
        if not isinstance(new_module, nn.Module):
            raise ValueError('Not a Module')
        else:
            self.module.add_module(self.prefix + str(self.num_module), new_module)
            self.num_module += 1

    def __len__(self):
        return self.num_module

    def __getitem__(self, i):
        if i < 0 or i >= self.num_module:
            raise IndexError('Out of bound')
        return getattr(self.module, self.prefix + str(i))

In [3]:
class rectifier_mlp(nn.Module):
    def __init__(self):
        super(rectifier_mlp, self).__init__()
        self.fc1 = nn.Linear(784, 1024)
        self.fc2 = nn.Linear(1024, 10)


    def forward(self, x): 
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=0)

In [4]:
class maxout_mlp(nn.Module):
    def __init__(self, num_units=2):
        super(maxout_mlp, self).__init__()
        self.fc1_list = ListModule(self, "fc1_")
        self.fc2_list = ListModule(self, "fc2_")
        for _ in range(num_units):
            self.fc1_list.append(nn.Linear(784, 1024))
            self.fc2_list.append(nn.Linear(1024, 10))

    def forward(self, x): 
        x = x.view(-1, 784)
        x = self.maxout(x, self.fc1_list)
        x = F.dropout(x, training=self.training)
        x = self.maxout(x, self.fc2_list)
        return F.log_softmax(x, dim=0)

    def maxout(self, x, layer_list):
        max_output = layer_list[0](x)
        for _, layer in enumerate(layer_list, start=1):
            max_output = torch.max(max_output, layer(x))
        return max_output

In [5]:
class rectifier_conv_net(nn.Module):
    def __init__(self):
        super(rectifier_conv_net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
        self.fc1 = nn.Linear(64*7*7, 1024)
        self.fc2 = nn.Linear(1024, 10)


    def forward(self, x): 
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 64*7*7)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=0)

In [6]:
class maxout_conv_net(nn.Module):
    def __init__(self, num_units=2):
        super(maxout_conv_net, self).__init__()
        self.conv1_list = ListModule(self, "conv1_")
        self.conv2_list = ListModule(self, "conv2_")
        self.fc1_list = ListModule(self, "fc1_")
        self.fc2_list = ListModule(self, "fc2_")
        for _ in range(num_units):
            self.conv1_list.append(nn.Conv2d(1, 32, 5, padding=2))
            self.conv2_list.append(nn.Conv2d(32, 64, 5, padding=2))
            self.fc1_list.append(nn.Linear(64*7*7, 1024))
            self.fc2_list.append(nn.Linear(1024, 10))

    def forward(self, x): 
        x = F.max_pool2d(self.maxout(x, self.conv1_list), 2)
        x = F.max_pool2d(self.maxout(x, self.conv2_list), 2)
        x = x.view(-1, 64*7*7)
        x = self.maxout(x, self.fc1_list)
        x = F.dropout(x, training=self.training)
        x = self.maxout(x, self.fc2_list)
        return F.log_softmax(x, dim=0)

    def maxout(self, x, layer_list):
        max_output = layer_list[0](x)
        for _, layer in enumerate(layer_list, start=1):
            max_output = torch.max(max_output, layer(x))
        return max_output

In [7]:
def train(epoch,  net, train_loss, train_acc): 
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    net.train()
    for batch_idx, (data, target) in enumerate(train_loader): 
        if cuda:
            data, target = Variable(data).cuda(0), Variable(target).cuda(0)
        else:
            data, target = Variable(data), Variable(target)
            
        optimizer.zero_grad()
        output = net(data)
        loss = F.cross_entropy(output, target)
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()
        prediction = output.data.max(1)[1]
        accuracy = prediction.eq(target.data).sum()*1.0/batch_size*100.0 
        train_acc.append(accuracy)
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [8]:
def test(epoch, net):
    net.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            output = net(data)
            test_loss += F.cross_entropy(output, target).data.item()
            pred = output.data.max(1)[1] # get the index of the max log-probability
            correct += pred.eq(target.data).cpu().sum()

    test_loss = test_loss
    test_loss /= len(test_loader) # loss function already averages over batch size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [10]:
learning_rate = 0.0001
batch_size = 50
n_epochs = 3
cuda = torch.cuda.is_available()

In [11]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../Data_MNIST', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                       ])),
                       batch_size=batch_size, shuffle=True)

In [12]:
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../Data_MNIST', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                       ])),
                       batch_size=batch_size, shuffle=True)

In [13]:
if cuda:
    net_relu = rectifier_conv_net().cuda()
else:
    net_relu = rectifier_conv_net()
relu_loss = [] 
relu_acc = []
for epoch in range(n_epochs): 
    train(epoch, net_relu, relu_loss, relu_acc)
    test(epoch, net_relu)


Test set: Average loss: 0.0851, Accuracy: 9725/10000 (97%)


Test set: Average loss: 0.0545, Accuracy: 9814/10000 (98%)


Test set: Average loss: 0.0501, Accuracy: 9814/10000 (98%)



In [None]:
if cuda:
    net_maxout = maxout_conv_net(num_units=5).cuda() #this uses 5 "maxout units" per "layer" 
else:
    net_maxout = maxout_conv_net(num_units=5)
maxout_loss = [] 
maxout_acc = []
for epoch in range(n_epochs): 
    train(epoch, net_maxout, maxout_loss, maxout_acc)
    test(epoch, net_maxout)

In [None]:
if cuda:
    mlp_relu = rectifier_mlp().cuda() 
else:
    mlp_relu = rectifier_mlp()
relu_mlp_loss = [] 
relu_mlp_acc = []
for epoch in range(n_epochs): 
    train(epoch, mlp_relu, relu_mlp_loss, relu_mlp_acc)
    test(epoch, mlp_relu)

In [None]:
if cuda:
    mlp_maxout = maxout_mlp(num_units=5).cuda() 
else:
    mlp_maxout = maxout_mlp(num_units=5)
maxout_mlp_loss = [] 
maxout_mlp_acc = []
for epoch in range(n_epochs): 
    train(epoch, mlp_maxout, maxout_mlp_loss, maxout_mlp_acc)
    test(epoch, mlp_maxout)

In [None]:
sns.set_context("paper")
plot1, = plt.plot(maxout_mlp_loss, label='maxout_mlp')
plot2, = plt.plot(relu_mlp_loss, label='relu_mlp')
plt.legend(handles=[plot1, plot2])

In [None]:
plot1, = plt.plot(np.arange(len(maxout_mlp_acc)), maxout_mlp_acc, label='maxout_conv')
plot2, = plt.plot(np.arange(len(relu_mlp_acc)), relu_mlp_acc, label='relu_conv')
plt.legend(handles=[plot1, plot2])