# Jupyter notebook (.ipynb) version of train_2layer_withnn.py

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [2]:
## define nn model
class Net(nn.Module):
    """Small convolutional classifier: two conv layers, then two linear layers.

    ``forward`` flattens the conv features to 320 values per sample, which
    matches 1x28x28 inputs (28 -> 24 -> 12 after conv1/pool, 12 -> 8 -> 4
    after conv2/pool, 20 channels * 4 * 4 = 320).  The output is a row of
    10 log-probabilities per sample.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Feature extractor: two 5x5 convolutions without padding.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        # Channel-wise dropout applied to the output of conv2.
        self.conv2_drop = nn.Dropout2d()
        # Classifier head: 320 -> 50 -> 10.
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities (softmax over dim 1) for ``x``."""
        # Stage 1: conv -> 2x2 max-pool -> ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))

        # Stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU.
        dropped = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(dropped, 2))

        # Flatten to (rows, 320) for the fully connected head.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Functional dropout follows the module's train/eval mode explicitly.
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [3]:
## define training function
def train(epoch):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``args`` (reads ``args.cuda`` and ``args.log_interval``).  Requires
    PyTorch >= 0.4 (``Tensor.item``).

    NOTE(review): no ``optimizer`` is created in the visible notebook cells;
    it must be defined before this function is called.

    Args:
        epoch: Epoch number; used only in the progress message.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # Tensors carry autograd state themselves since PyTorch 0.4, so the
        # deprecated Variable wrapper is not needed.
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            # loss is a 0-dim tensor: `.data[0]` raises IndexError on current
            # PyTorch releases; `.item()` extracts the Python float.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [4]:
## define test function
def test():
    """Evaluate ``model`` on ``test_loader`` and print mean loss and accuracy.

    Uses the module-level ``model``, ``test_loader`` and ``args`` (reads
    ``args.cuda``).  Prints a summary line and returns nothing.  Requires
    PyTorch >= 0.4.1 (``torch.no_grad``, ``Tensor.item``, ``reduction=``).
    """
    model.eval()
    test_loss = 0
    correct = 0
    # `Variable(..., volatile=True)` no longer disables autograd; no_grad()
    # is the supported way to skip graph construction during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # Sum (not average) the per-sample losses so that dividing by the
            # dataset size below gives the true mean even if batch sizes differ.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Index of the max log-probability is the predicted class.
            pred = output.max(1, keepdim=True)[1]
            # .item() keeps `correct` a Python int, so the summary prints a
            # plain count and the percentage uses float division.
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))

In [43]:
class args:
    """Hard-coded run settings, read as attributes (``args.lr``, ``args.cuda``, ...).

    Presumably stands in for the argparse options of the script version;
    verify against train_2layer_withnn.py.
    """

    # Mini-batch sizes for the training and test loaders.
    batch_size = 64
    test_batch_size = 1000

    # Optimisation settings.
    epochs = 10
    lr = 0.01
    momentum = 0.5

    # Device selection, RNG seed, and how often (in batches) train() logs.
    no_cuda = True
    seed = 1
    log_interval = 10

    # Evaluated once, when the class body runs; CUDA availability is only
    # queried when no_cuda is False.
    cuda = False if no_cuda else torch.cuda.is_available()

In [44]:
# Rebind `args` from the settings class to an instance of it; attribute access
# (args.lr, args.cuda, ...) works the same way afterwards.
# NOTE: after this runs the class is no longer reachable by name, so re-running
# this cell fails (the instance is not callable) until the class cell is re-run.
args = args()

In [45]:
# Seed the RNGs for repeatable runs and choose the extra DataLoader options.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    # Extra loader options are only passed on the CUDA path.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

In [46]:
## build the MNIST train/test loaders (the training split is downloaded to ../data if missing)
# Both splits share the same preprocessing: convert to a tensor, then
# normalise with mean 0.1307 and std 0.3081.
_mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

_train_set = datasets.MNIST('../data', train=True, download=True,
                            transform=_mnist_transform)
train_loader = torch.utils.data.DataLoader(
    _train_set, batch_size=args.batch_size, shuffle=True, **kwargs)

# The test split is not downloaded here; it relies on the files fetched above.
_test_set = datasets.MNIST('../data', train=False,
                           transform=_mnist_transform)
test_loader = torch.utils.data.DataLoader(
    _test_set, batch_size=args.test_batch_size, shuffle=True, **kwargs)

In [47]:
# Inspect the loader: its configured batch size and the dataset object it wraps.
train_loader.batch_size, train_loader.dataset

(64, <torchvision.datasets.mnist.MNIST at 0x7fcb467910b8>)

In [58]:
# Instantiate the CNN defined above; train() and test() use this global.
# NOTE(review): the model is never moved to the GPU (.cuda()) in the visible
# cells, while train()/test() move batches there when args.cuda is set.
# Harmless with no_cuda = True, but confirm before enabling CUDA.
model = Net()

In [59]:
# Display the layer summary of the CNN.
model

Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)

In [70]:
# Display the loader object itself (shows only its default repr).
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7fcb49cf9f98>

In [69]:
# Walk through every training batch once and print the tensor shapes the
# model will receive.  No training happens here.
for batch_idx, (data, target) in enumerate(train_loader):
    if args.cuda:
        data, target = data.cuda(), target.cuda()
    # The deprecated Variable wrapper is a no-op since PyTorch 0.4, so the
    # tensors from the loader are used directly.

    print("batch_idx = {}, data shape = {}, target shape = {}".format(batch_idx, data.shape, target.shape))

batch_idx = 0, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 1, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 2, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 3, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 4, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 5, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 6, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 7, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 8, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 9, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 10, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 11, data

batch_idx = 95, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 96, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 97, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 98, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 99, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 100, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 101, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 102, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 103, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 104, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 105, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batc

batch_idx = 193, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 194, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 195, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 196, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 197, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 198, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 199, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 200, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 201, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 202, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 203, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])

batch_idx = 284, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 285, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 286, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 287, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 288, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 289, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 290, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 291, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 292, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 293, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 294, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])

batch_idx = 424, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 425, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 426, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 427, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 428, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 429, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 430, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 431, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 432, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 433, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 434, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])

batch_idx = 521, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 522, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 523, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 524, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 525, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 526, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 527, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 528, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 529, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 530, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 531, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])

batch_idx = 613, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 614, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 615, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 616, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 617, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 618, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 619, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 620, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 621, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 622, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 623, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])

batch_idx = 746, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 747, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 748, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 749, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 750, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 751, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 752, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 753, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 754, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 755, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 756, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])

batch_idx = 841, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 842, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 843, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 844, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 845, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 846, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 847, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 848, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 849, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 850, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])
batch_idx = 851, data shape = torch.Size([64, 1, 28, 28]), target shape = torch.Size([64])

In [73]:
# Sizes for the fully connected model built in the next cell:
#   D_in  = 784 input features (28 * 28, the image size printed above, flattened)
#   H     = 256 hidden units
#   D_out = 10 outputs
# N = 64 equals the training batch size but is not used in the visible cells.
N, D_in, H, D_out = 64, 784, 256, 10

In [74]:
# Two-layer fully connected network: D_in -> H with ReLU, then H -> D_out.
# It has no flatten step, so inputs must already have D_in features in the
# last dimension.
model_new = nn.Sequential(
    nn.Linear(D_in, H),
    nn.ReLU(),
    nn.Linear(H, D_out),
)

In [76]:
# Display the layer summary of the fully connected model.
model_new

Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

In [77]:
# Display the CNN again for side-by-side comparison with model_new.
model

Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)

In [78]:
# NOTE: without parentheses this only shows the bound method object (see the
# output below); call model.modules() to get an iterator over the submodules.
model.modules

<bound method Module.modules of Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)>

In [79]:
# NOTE: attribute access only, so this shows the bound method; call
# model_new.modules() to iterate over the submodules.
model_new.modules

<bound method Module.modules of Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)>