Source: https://github.com/bamos/densenet.pytorch by Brandon Amos, J. Zico Kolter.

Due to a bug in PyTorch, let's wait a while before rerunning this notebook.

## Setup

In [1]:
# default libraries
import os
import sys
import math
import argparse
from IPython.core.debugger import Tracer

In [2]:
# pytorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

In [3]:
# global parameters
# Built directly as a Namespace so that `args` never changes type.
args = argparse.Namespace(
    data="../data/cifar/",
    cuda=True,
    seed=7,
    workers=4,
    optim="sgd",  # adam, rmsprop
    epochs=300,
    batch_size=64,
    lr=1e-1,
    momentum=0.9,
    weight_decay=1e-4,
    intermediate="../intermediate/densenet/",
)

# Create the data and output directories. os.makedirs also creates missing
# parent directories (a plain `mkdir` fails on them) and is portable.
for directory in (args.data, args.intermediate):
    os.makedirs(directory, exist_ok=True)

# Only use CUDA when it was requested *and* is actually available.
args.cuda = args.cuda and torch.cuda.is_available()

# Seed the CPU (and, if used, the GPU) random number generators.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

## Define the model

In [4]:
# define some layers
class Bottleneck(nn.Module):
    """Dense layer with a bottleneck: BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3).

    The 1x1 convolution maps the input to ``4 * growthRate`` channels, the
    3x3 convolution then produces ``growthRate`` feature maps, and these are
    concatenated to the layer input along the channel dimension.
    """
    def __init__(self, nChannels, growthRate):
        super().__init__()  # nn.Module must be initialised first

        bottleneckWidth = 4 * growthRate  # width of the 1x1 stage, as in the paper
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, bottleneckWidth, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneckWidth)
        self.conv2 = nn.Conv2d(
            bottleneckWidth, growthRate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        # F.* functions carry no weights of their own.
        narrowed = self.conv1(F.relu(self.bn1(x)))
        new_features = self.conv2(F.relu(self.bn2(narrowed)))
        # Keep the input and append the new feature maps channel-wise.
        return torch.cat((x, new_features), 1)

    
class SingleLayer(nn.Module):
    """Plain dense layer: BN-ReLU-Conv(3x3).

    Produces ``growthRate`` feature maps and concatenates them to the layer
    input along the channel dimension.
    """
    def __init__(self, nChannels, growthRate):
        super().__init__()

        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, growthRate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        new_features = self.conv1(F.relu(self.bn1(x)))
        # Keep the input and append the new feature maps channel-wise.
        return torch.cat((x, new_features), 1)

    
class Transition(nn.Module):
    """Transition between dense blocks: BN-ReLU-Conv(1x1)-AvgPool(2x2).

    The 1x1 convolution changes the channel count from ``nChannels`` to
    ``nOutChannels``; the average pooling halves the spatial resolution.
    """
    def __init__(self, nChannels, nOutChannels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        compressed = self.conv1(F.relu(self.bn1(x)))
        # Pooling has no weights, so the functional form is enough.
        return F.avg_pool2d(compressed, 2)

In [5]:
# main model
class DenseNet(nn.Module):
    """DenseNet made of three dense blocks separated by two transitions.

    Args:
        growthRate: number of feature maps every dense layer adds.
        depth: overall depth; each dense block gets ``(depth - 4) // 3``
            layers, halved when ``bottleneck`` is set.
        reduction: factor applied to the channel count at each transition.
        nClasses: number of output classes.
        bottleneck: use ``Bottleneck`` layers when true, ``SingleLayer``
            layers otherwise.

    ``forward`` returns log-probabilities of shape ``(batch, nClasses)``.
    The final 8x8 average pooling has to reduce the feature map to 1x1 (as
    it does for 32x32 inputs) for the classifier's input size to match.
    """
    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()

        nDenseBlocks = (depth-4) // 3
        if bottleneck:
            # A bottleneck layer holds two convolutions, so half as many fit.
            nDenseBlocks //= 2
            nChannels = 2*growthRate
        else:
            nChannels = 16

        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,
                               bias=False)
        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks,
                                       bottleneck)
        nChannels += nDenseBlocks*growthRate
        nOutChannels = int(math.floor(nChannels*reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks,
                                       bottleneck)
        nChannels += nDenseBlocks*growthRate
        nOutChannels = int(math.floor(nChannels*reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks,
                                       bottleneck)
        nChannels += nDenseBlocks*growthRate

        self.bn1 = nn.BatchNorm2d(nChannels)
        self.fc = nn.Linear(nChannels, nClasses)

        # Weight initialisation: normal(0, sqrt(2 / (k_h * k_w * out_channels)))
        # for convolutions, identity for batch norm, zero bias for the
        # classifier.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        """Build one dense block of ``nDenseBlocks`` stacked layers."""
        layers = []
        for i in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(nChannels, growthRate))
            else:
                layers.append(SingleLayer(nChannels, growthRate))
            # Every layer concatenates growthRate new channels to its input.
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        out = F.avg_pool2d(F.relu(self.bn1(out)), 8)
        # Flatten everything except the batch dimension. A bare squeeze
        # would also drop the batch dimension for a batch of one sample.
        out = out.view(out.size(0), -1)
        # Normalise over the class dimension explicitly.
        out = F.log_softmax(self.fc(out), dim=1)
        return out

## Utility functions

In [6]:
def adjust_opt(optAlg, optimizer, epoch):
    """Apply the step learning-rate schedule used with SGD.

    At epoch 150 the learning rate of every parameter group is set to 1e-2
    and at epoch 225 to 1e-3. Any other epoch, and any optimiser other than
    "sgd", leaves the optimizer untouched.
    """
    if optAlg != "sgd":
        return

    schedule = {150: 1e-2, 225: 1e-3}
    new_lr = schedule.get(epoch)
    if new_lr is None:
        return

    for group in optimizer.param_groups:
        group["lr"] = new_lr

## Train/test functions

In [7]:
def train(args, epoch, net, trainLoader, optimizer, trainF):
    """Train ``net`` for one epoch and log every batch to ``trainF``.

    Args:
        args: namespace whose ``cuda`` flag decides whether batches are
            moved to the GPU.
        epoch: 1-based epoch number, used only for logging.
        net: model returning log-probabilities of shape (batch, classes).
        trainLoader: DataLoader yielding ``(data, target)`` batches.
        optimizer: optimizer stepping the parameters of ``net``.
        trainF: writable text file; one ``partialEpoch,loss,err`` line is
            appended and flushed per batch.
    """
    net.train()  # affects layers such as Dropout or BatchNorm
    nProcessed = 0  # number of samples processed so far
    nTrain = len(trainLoader.dataset)
    nBatches = len(trainLoader)
    for batch_idx, (data, target) in enumerate(trainLoader):
        if args.cuda:
            # `async` is a reserved word since Python 3.7; PyTorch renamed
            # the argument to `non_blocking`.
            data = data.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        output = net(data)
        # The output holds log-probabilities, so the loss is the negative
        # log-probability at the target class.
        loss = F.nll_loss(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batchSize = len(data)
        nProcessed += batchSize
        # Index of the largest log-probability is the predicted class.
        pred = output.detach().max(1)[1]
        incorrect = pred.ne(target).sum().item()
        err = 100. * incorrect / batchSize
        lossValue = loss.item()  # plain Python float for logging
        partialEpoch = epoch + batch_idx / nBatches - 1
        if batch_idx % 100 == 0:
            print("Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\t"
                  "Loss: {:.6f}\tError: {:.6f}"
                  .format(partialEpoch, nProcessed, nTrain,
                          100. * batch_idx / nBatches,
                          lossValue, err))

        trainF.write("{},{},{}\n".format(partialEpoch, lossValue, err))
        trainF.flush()

In [8]:
def test(args, epoch, net, testLoader, optimizer, testF):
    net.eval()
    test_loss = 0
    incorrect = 0
    for data, target in testLoader:
        if args.cuda:
            data, target = data.cuda(async=True), target.cuda(async=True)
        data, target = Variable(data, volatile=True), Variable(target)
        output = net(data)
        test_loss += F.nll_loss(output, target).data[0]
        pred = output.data.max(1)[1] # get the index of the max log-probability
        incorrect += pred.ne(target.data).cpu().sum()

    test_loss = test_loss
    # loss function already averages over batch size
    test_loss /= len(testLoader)
    nTotal = len(testLoader.dataset)
    err = 100.*incorrect/nTotal
    print("\nTest set: Average loss: {:.4f}, Error: {}/{} ({:.0f}%)\n"
          .format(test_loss, incorrect, nTotal, err))

    testF.write("{},{},{}\n".format(epoch, test_loss, err))
    testF.flush()

## Prepare data

In [9]:
# Mean and standard deviation handed to Normalize, one value per channel
# (presumably the CIFAR-10 training-set statistics -- TODO confirm).
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)

# Training images are augmented with a padded random crop and a random
# horizontal flip; test images are only converted and normalised.
trainTransform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normTransform,
])
testTransform = transforms.Compose([transforms.ToTensor(), normTransform])

# Pinned host memory is only requested when CUDA is in use.
kwargs = {}
if args.cuda:
    kwargs["pin_memory"] = True

trainSet = datasets.CIFAR10(root=args.data, train=True, download=True,
                            transform=trainTransform)
trainLoader = DataLoader(trainSet, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, **kwargs)

testSet = datasets.CIFAR10(root=args.data, train=False,
                           transform=testTransform)
testLoader = DataLoader(testSet, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, **kwargs)

Files already downloaded and verified
Files already downloaded and verified


## Run the model

In [10]:
# create the model
net = DenseNet(
    growthRate=12,
    depth=100,
    reduction=0.5,
    bottleneck=True,
    nClasses=10,
)

# Total number of scalar parameters in the network.
nParams = sum(p.data.nelement() for p in net.parameters())
print("  + Number of params: {}".format(nParams))

if args.cuda:
    net = net.cuda()

  + Number of params: 769162


In [11]:
# define the optimizer
if args.optim == "sgd":
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
elif args.optim == "adam":
    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
elif args.optim == "rmsprop":
    optimizer = optim.RMSprop(net.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
else:
    # Fail here rather than with a NameError on `optimizer` in a later cell.
    raise ValueError(
        "Unknown optimizer {!r}; expected 'sgd', 'adam' or 'rmsprop'"
        .format(args.optim))

In [12]:
# CSV logs for the per-batch training and per-epoch test statistics,
# opened for writing (existing files are truncated).
trainLogPath = os.path.join(args.intermediate, "train.csv")
testLogPath = os.path.join(args.intermediate, "test.csv")
trainF = open(trainLogPath, "w")
testF = open(testLogPath, "w")

In [13]:
# train
try:
    for epoch in range(1, args.epochs + 1):
        adjust_opt(args.optim, optimizer, epoch)
        train(args, epoch, net, trainLoader, optimizer, trainF)
        test(args, epoch, net, testLoader, optimizer, testF)
        # Overwrite the checkpoint after every epoch.
        torch.save(net, os.path.join(args.intermediate, "latest.pth"))
finally:
    # Close the logs even if training is interrupted (e.g. by a
    # KeyboardInterrupt) or fails.
    trainF.close()
    testF.close()


Test set: Average loss: 1.8923, Error: 7016/10000 (70%)



  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "



Test set: Average loss: 1.8319, Error: 6872/10000 (69%)


Test set: Average loss: 2.2125, Error: 7778/10000 (78%)


Test set: Average loss: 2.3071, Error: 9000/10000 (90%)


Test set: Average loss: 2.3072, Error: 9000/10000 (90%)


Test set: Average loss: 2.3066, Error: 9000/10000 (90%)


Test set: Average loss: 2.3055, Error: 9000/10000 (90%)


Test set: Average loss: 2.3069, Error: 9000/10000 (90%)


Test set: Average loss: 2.3112, Error: 9000/10000 (90%)


Test set: Average loss: 2.3054, Error: 9000/10000 (90%)


Test set: Average loss: 2.3069, Error: 9000/10000 (90%)


Test set: Average loss: 2.3094, Error: 9000/10000 (90%)


Test set: Average loss: 2.3077, Error: 9000/10000 (90%)



Process Process-105:
Process Process-106:
Process Process-108:
Process Process-107:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 26, in _worker_loop
    r = index_queue.get()
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.5/multiprocessing/process

KeyboardInterrupt: 