# Prototype experiment using only dense layers

In [3]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [4]:
# Example DNA using 3 dense layers.
# Each gene is a [layer type, value] pair: "D" = dense layer with `value`
# output features (see ConstructNet for the full gene vocabulary).
DNA = [
    ["D", 8192],
    ["D", 4096],
    ["D", 128],
]

In [15]:
# Here we add the input and output size as well (28x28 and 10 for mnist)
class ConstructNet(nn.Module):
    """Feed-forward classifier built from a list of "genes" (the DNA).

    Supported genes (each a ``[kind, value]`` pair):

    * ``["D", n]`` -- an ``nn.Linear`` with ``n`` output features followed by
      ``nn.ReLU``. Its input width is the output width of the previous dense
      layer, or ``input_size`` if it is the first dense layer.
    * ``["R", p]`` -- an ``nn.Dropout`` with drop probability ``p``.

    A final ``nn.Linear(<last width>, output_size)`` and
    ``nn.LogSoftmax(dim=1)`` are always appended, so the output can be fed
    straight into ``F.nll_loss``.

    Args:
        DNA: sequence of ``[kind, value]`` genes as described above. May be
            empty, in which case the model is a single linear classifier.
        input_size: number of input features per sample.
        output_size: number of output classes.

    Raises:
        ValueError: if a gene kind is neither ``"D"`` nor ``"R"``.
    """

    def __init__(self, DNA, input_size=28*28, output_size=10):
        super().__init__()
        self.DNA = DNA
        self.input_size = input_size
        self.output_size = output_size
        # Plain list used while building; the modules are registered with
        # PyTorch (parameters, .to(device), state_dict) through ``self.net``.
        self.layers = []

        # Every gene, including the first, goes through the same logic. Dense
        # layers take their input width from the previous dense layer, or from
        # input_size when none exists yet.
        for gene in self.DNA:
            kind, value = gene[0], gene[1]
            if kind == "D":
                self.layers.append(nn.Linear(self.last_layer_output_size(), value))
                self.layers.append(nn.ReLU())
            elif kind == "R":
                self.layers.append(nn.Dropout(value))
            else:
                raise ValueError(
                    f"Unknown gene type {kind!r} in DNA; expected 'D' or 'R'")

        # Append the output layer. Look up the last *dense* layer instead of
        # using self.layers[-2]: when the DNA ends with a dropout gene,
        # layers[-2] is a ReLU and has no out_features.
        self.layers.append(nn.Linear(self.last_layer_output_size(), self.output_size))
        self.layers.append(nn.LogSoftmax(dim=1))
        self.net = nn.Sequential(*self.layers)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of flattened inputs."""
        return self.net(x)

    def last_layer_output_size(self):
        """Return the output width of the most recently added dense layer.

        Scans ``self.layers`` from the end for the last ``nn.Linear``. If no
        dense layer has been added yet, ``self.input_size`` is returned, so the
        next dense layer is connected directly to the network input.
        """
        for layer in reversed(self.layers):
            if isinstance(layer, nn.Linear):
                return layer.out_features
        return self.input_size
        

In [16]:
# Now we copy the training code from the mnist example
# Training settings
batch_size = 16
use_mps = True and torch.backends.mps.is_available()
use_cuda = torch.cuda.is_available()
test_batch_size = 128
epochs = 14
lr = 1.0
gamma = 0.7
seed = 1
log_interval = 10
save_model = False

# Seed the global torch RNG so weight initialisation (the model is built in a
# later cell) and DataLoader shuffling are reproducible across runs.
torch.manual_seed(seed)

# Prefer Apple MPS, then CUDA, then fall back to the CPU.
if use_mps:
    device = torch.device("mps")
elif use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reshuffle the training set every epoch. The upstream example only enabled
# shuffling in its CUDA branch, so MPS/CPU runs would otherwise see the
# batches in a fixed order every epoch.
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_kwargs = {'batch_size': test_batch_size}

# Normalise with the constants used by the PyTorch MNIST example, then flatten
# each image into a 1-D vector so it can feed the dense network directly.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(lambda x: torch.flatten(x)),
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                          transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                          transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)


In [17]:
# Input length: number of features in one (already flattened) training sample.
input_size = dataset1[0][0].numel()
# Number of classes; fixed for our mnist example.
output_size = 10

In [18]:

# Build the network described by DNA and move its parameters to `device`.
model = ConstructNet(DNA, input_size=input_size, output_size=output_size).to(device)

# Adadelta optimiser. StepLR multiplies the learning rate by `gamma` every time
# scheduler.step() is called (step_size=1); the training loop calls it once
# per epoch.
optimizer = optim.Adadelta(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)


In [19]:
# Display the generated architecture (rendered as the cell output below).
model

ConstructNet(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=8192, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8192, out_features=4096, bias=True)
    (3): ReLU()
    (4): Linear(in_features=4096, out_features=128, bias=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=10, bias=True)
    (7): LogSoftmax(dim=1)
  )
)

In [10]:
def train(model, device, train_loader, optimizer, epoch, log_interval):
    """Run one training epoch, minimising NLL loss on the model's log-probabilities.

    Args:
        model: network returning log-probabilities (e.g. ending in LogSoftmax).
        device: device the batches are moved to before the forward pass.
        train_loader: iterable of (inputs, targets) batches exposing
            ``.dataset`` and ``len()``.
        optimizer: optimiser stepped once per batch.
        epoch: epoch number, used only for logging.
        log_interval: print a progress line every ``log_interval`` batches.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        if step % log_interval != 0:
            continue
        samples_seen = step * len(inputs)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, len(train_loader.dataset),
            percent_done, batch_loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Parameters with ``requires_grad=False`` (e.g. frozen layers) are excluded.
    """
    # https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [11]:
# Report how many trainable parameters the DNA-generated network has.
print(f"Parameter count: {count_parameters(model)}")

Parameter count: 40514954


In [12]:
# Train for `epochs` epochs. After each epoch, evaluate on the test set and
# advance the StepLR scheduler, which decays the learning rate by `gamma`.
# NOTE(review): re-running this cell continues training the existing model,
# optimizer and scheduler; re-run the setup cells first for a fresh run.
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval)
    test(model, device, test_loader)
    scheduler.step()

# Optionally persist the trained weights.
# NOTE: the filename is inherited from the PyTorch MNIST CNN example even
# though this model is dense-only.
if save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")


Test set: Average loss: 0.2980, Accuracy: 9429/10000 (94%)


Test set: Average loss: 0.2025, Accuracy: 9632/10000 (96%)


Test set: Average loss: 0.1664, Accuracy: 9707/10000 (97%)


Test set: Average loss: 0.1750, Accuracy: 9767/10000 (98%)


Test set: Average loss: 0.1963, Accuracy: 9793/10000 (98%)


Test set: Average loss: 0.2162, Accuracy: 9786/10000 (98%)


Test set: Average loss: 0.2159, Accuracy: 9809/10000 (98%)


Test set: Average loss: 0.2249, Accuracy: 9811/10000 (98%)


Test set: Average loss: 0.2195, Accuracy: 9816/10000 (98%)


Test set: Average loss: 0.2181, Accuracy: 9824/10000 (98%)


Test set: Average loss: 0.2179, Accuracy: 9826/10000 (98%)


Test set: Average loss: 0.2172, Accuracy: 9829/10000 (98%)


Test set: Average loss: 0.2166, Accuracy: 9828/10000 (98%)


Test set: Average loss: 0.2161, Accuracy: 9827/10000 (98%)


Dropout(p=0.2, inplace=False)