In [1]:
# Delete the ./logs directory and everything in it. Both the TensorBoard server
# (--logdir ./logs) and the SummaryWriter used during training (./logs/cifar10/...)
# work under this path.
!rm -rf ./logs

In [2]:
import os
import subprocess
import signal

class TensorBoardServer():
    """
    Runs the `tensorboard` command as a background process that the notebook can
    start and stop.

    The command is launched through a shell with `os.setsid` as `preexec_fn`, so the
    child becomes the leader of its own process group and `stop` can signal that
    whole group with `os.killpg`.
    """
    def __init__(self):
        # Popen handle of the server; None whenever no server has been started
        # (or after it has been stopped).
        self.process = None

    def start(self):
        """
        Launch tensorboard on 127.0.0.1:6006 with ./logs as its log directory.

        Does nothing if a previously started server is still running, because
        overwriting `self.process` would lose the only handle to that process group.
        """
        if self.process is not None and self.process.poll() is None:
            return
        self.process = subprocess.Popen("tensorboard --logdir ./logs --host 127.0.0.1 --port 6006",
                                  shell=True, preexec_fn=os.setsid)

    def stop(self):
        """
        Send SIGTERM to the server's process group.

        Safe to call when the server was never started, was already stopped, or has
        exited on its own.
        """
        if self.process is None:
            return
        # Only signal while the child is still alive: once it has exited and been
        # reaped, its pid may no longer identify the process group we created.
        if self.process.poll() is None:
            os.killpg(self.process.pid, signal.SIGTERM)
        self.process = None
        
# Launch the TensorBoard process (127.0.0.1:6006, reading ./logs) before any run is
# logged; it is stopped again in the notebook's last cell.
tb_server = TensorBoardServer()
tb_server.start()

In [3]:
import datetime
import math
from mxboard import SummaryWriter
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms
import os

## Data

In [4]:
# Mini-batch size used by both the training and the test loader.
batch_size = 128

# ToTensor is the only transform at this stage; no normalization is applied yet.
transform_fn = transforms.Compose([transforms.ToTensor()])

# transform_first applies the transform to the image only, leaving the label untouched.
train_dataset = mx.gluon.data.vision.CIFAR10(train=True).transform_first(transform_fn)
train_dataloader = mx.gluon.data.DataLoader(dataset=train_dataset, batch_size=batch_size)

In [5]:
# Test split of CIFAR-10, passed through the same transform as the training images.
test_dataset = mx.gluon.data.vision.CIFAR10(train=False).transform_first(transform_fn)
test_dataloader = mx.gluon.data.DataLoader(dataset=test_dataset, batch_size=batch_size)

## Model

In [6]:
class BasicBlock(nn.HybridBlock):
    """
    Pre-activation residual block: BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 3x3 conv,
    added to a shortcut branch.

    Parameters
    ----------
    channels : number of output channels of every convolution in the block
    stride : stride of the first 3x3 convolution (and of the shortcut convolution)
    dim_match : if True the block input is used as the shortcut unchanged; if False
        a 1x1 convolution of the first activation is used instead
    """
    def __init__(self, channels, stride=1, dim_match=True):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.dim_match = dim_match
        # Settings shared by both 3x3 convolutions.
        conv3x3 = dict(channels=channels, kernel_size=3, padding=1, use_bias=False)
        with self.name_scope():
            self.bn1 = nn.BatchNorm(epsilon=2e-5)
            self.conv1 = nn.Conv2D(strides=stride, **conv3x3)
            self.bn2 = nn.BatchNorm(epsilon=2e-5)
            self.conv2 = nn.Conv2D(strides=1, **conv3x3)
            if not dim_match:
                # Shortcut convolution, only created when the identity cannot be used.
                self.conv3 = nn.Conv2D(channels=channels, kernel_size=1, padding=0,
                                       strides=stride, use_bias=False)

    def hybrid_forward(self, F, x):
        pre_act = F.relu(self.bn1(x))
        hidden = self.conv1(pre_act)
        hidden = F.relu(self.bn2(hidden))
        residual = self.conv2(hidden)
        # The shortcut convolution consumes the pre-activated input, not the raw x.
        shortcut = x if self.dim_match else self.conv3(pre_act)
        return residual + shortcut


class ResNet(nn.HybridBlock):
    """
    Pre-activation ResNet assembled as one HybridSequential: an input BatchNorm, a
    3x3 stem convolution, three stages of four BasicBlocks (16, 32 and 64 channels),
    then BatchNorm, ReLU, global average pooling and a Dense layer with
    `num_classes` outputs.
    """
    def __init__(self, num_classes):
        super(ResNet, self).__init__()
        # (channels, stride of the stage's first block) for stages 1-3.
        stage_specs = ((16, 1), (32, 2), (64, 2))
        with self.name_scope():
            layers = nn.HybridSequential()
            self.net = layers
            # data normalization
            layers.add(nn.BatchNorm(epsilon=2e-5, scale=True))
            # pre-stage
            layers.add(nn.Conv2D(channels=16, kernel_size=3, strides=1, padding=1, use_bias=False))
            # Each stage: one block with a convolutional shortcut, then three identity-shortcut blocks.
            for channels, first_stride in stage_specs:
                layers.add(BasicBlock(channels, stride=first_stride, dim_match=False))
                for _ in range(3):
                    layers.add(BasicBlock(channels, stride=1, dim_match=True))
            # post-stage (required as using pre-activation blocks)
            layers.add(nn.BatchNorm(epsilon=2e-5))
            layers.add(nn.Activation('relu'))
            layers.add(nn.GlobalAvgPool2D())
            layers.add(nn.Dense(num_classes))

    def hybrid_forward(self, F, x):
        # Feed the input through every child of the sequential container in order.
        activations = x
        for layer in self.net:
            activations = layer(activations)
        return activations

## Training

In [7]:
def markdown_table(data):
    """
    Render a mapping as a two-column Markdown table.

    The result starts with a "Key | Value" header and a separator row, followed by
    one "key | value" row per item of `data` (in iteration order); every line,
    including the last, ends with a newline.
    """
    lines = ["Key  | Value", "-----|-----"]
    lines.extend("{} | {}".format(key, value) for key, value in data.items())
    return "".join(line + "\n" for line in lines)

In [8]:
def accuracy(output, label):
    """
    Fraction of entries whose predicted class matches the label.

    The predicted class is the argmax of `output` along axis 1; predictions and
    labels are both cast to int32 before being compared. The mean of the matches
    is returned as a scalar.
    """
    predicted = output.argmax(axis=1).astype('int32')
    expected = label.astype('int32')
    hits = (predicted == expected).astype('float32')
    return mx.nd.mean(hits).asscalar()


def evaluate_accuracy(valid_data, model, ctx):
    """
    Accuracy of `model` over all samples yielded by `valid_data`.

    Every batch's accuracy is weighted by the number of samples in that batch, so
    a smaller final batch does not count as much as a full one and the result is
    the fraction of correctly classified samples rather than a mean of per-batch
    means.

    Parameters
    ----------
    valid_data : iterable yielding (data, label) batches
    model : callable mapping a data batch to per-class scores
    ctx : context the batches are moved to before the forward pass

    Returns
    -------
    Accuracy in [0, 1], or NaN when `valid_data` yields no samples.
    """
    weighted_acc = 0.
    num_samples = 0
    for data, label in valid_data:
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        output = model(data)
        batch_samples = data.shape[0]
        weighted_acc += accuracy(output, label) * batch_samples
        num_samples += batch_samples
    if num_samples == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError and make that visible.
        return float('nan')
    return weighted_acc / num_samples

In [9]:
def train_resnet(train_dataloader, test_dataloader, optimizer, description, num_epochs=10, ctx=None):
    """
    Train a newly initialised ResNet (10 classes) and log the run with mxboard.

    Parameters
    ----------
    train_dataloader : loader yielding (data, label) training batches
    test_dataloader : loader yielding (data, label) batches for the per-epoch test accuracy
    optimizer : optimizer instance handed to the Gluon Trainer
    description : name of the run; used as the last component of the log directory
    num_epochs : number of passes over `train_dataloader` (default 10)
    ctx : context to train on; `mx.gpu()` when omitted

    Returns
    -------
    The trained ResNet.

    Logged per epoch (global_step = epoch): the last training batch as images,
    histograms of that batch's input / output / loss, the trainer's learning rate,
    training accuracy over that epoch, and test accuracy.
    """
    if ctx is None:
        ctx = mx.gpu()
    kvstore = "device"

    run_id = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + "/" + description
    writer = SummaryWriter(logdir=os.path.join("./logs/cifar10", run_id))

    # try/finally so the event file is flushed and closed even if training fails part-way.
    try:
        net = ResNet(num_classes=10)
        # lazy initialize parameters
        net.initialize(mx.init.Xavier(), ctx=ctx)
        trainer = mx.gluon.Trainer(params=net.collect_params(), optimizer=optimizer, kvstore=kvstore)

        train_metric = mx.metric.Accuracy()
        loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss()

        run_description = markdown_table({
            "batch_size": train_dataloader._batch_sampler._batch_size,
            "optimizer": type(optimizer),
            # Not every optimizer has a momentum attribute; report None instead of failing.
            "optimizer_momentum": getattr(optimizer, "momentum", None),
            "optimizer_wd": optimizer.wd
        })
        writer.add_text(tag='run_description', text=run_description, global_step=0)

        for epoch in range(1, num_epochs + 1):
            # Start from a clean metric so the logged training accuracy covers this
            # epoch only instead of accumulating over all epochs so far.
            train_metric.reset()
            for data_batch, label_batch in train_dataloader:
                # move to required context (e.g. gpu)
                data_batch = data_batch.as_in_context(ctx)
                label_batch = label_batch.as_in_context(ctx)
                # take forward and backward pass
                with mx.autograd.record():
                    pred_batch = net(data_batch)
                    loss = loss_fn(pred_batch, label_batch)
                loss.backward()
                # step() is given the actual batch size, which may be smaller for the last batch
                bs = data_batch.shape[0]
                trainer.step(bs)
                train_metric.update(label_batch, pred_batch)

            # mxboard logging at end of each epoch; data_batch, pred_batch and loss
            # below are those of the epoch's last batch.

            ## sample of the images passed to network (rescaled to [0, 1] for display)
            adj_data_batch = (data_batch - data_batch.min())/(data_batch.max() - data_batch.min())
            writer.add_image(tag="batch", image=adj_data_batch, global_step=epoch)

            ## histograms of input, output and loss
            writer.add_histogram(tag='input', values=data_batch, global_step=epoch, bins=100)
            writer.add_histogram(tag='output', values=pred_batch, global_step=epoch, bins=100)
            writer.add_histogram(tag='loss', values=loss, global_step=epoch, bins=100)

            ## learning rate
            writer.add_scalar(tag="learning_rate", value=trainer.learning_rate, global_step=epoch)

            ## training accuracy
            _, trn_acc = train_metric.get()
            writer.add_scalar(tag='accuracy/training', value=trn_acc * 100, global_step=epoch)

            ## test accuracy
            test_acc = evaluate_accuracy(test_dataloader, net, ctx)
            writer.add_scalar(tag='accuracy/testing', value=test_acc * 100, global_step=epoch)

            print("Completed epoch {}".format(epoch))
    finally:
        writer.close()
    return net

In [10]:
def lr_schedule(iteration):
    """
    Learning rate for update number `iteration`.

    Grows linearly (iteration * 782 ** -1.5) until iteration 782, where both terms
    are equal, and decays as iteration ** -0.5 afterwards.
    """
    warmup = iteration * 782 ** -1.5
    decay = iteration ** -0.5
    return min(decay, warmup)


optimizer = mx.optimizer.SGD(lr_scheduler=lr_schedule)
trained_net = train_resnet(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    description="baseline",
)

Completed epoch 1
Completed epoch 2
Completed epoch 3
Completed epoch 4
Completed epoch 5
Completed epoch 6
Completed epoch 7
Completed epoch 8
Completed epoch 9
Completed epoch 10


## Update 1: Shuffle training data

In [42]:
# Same setup as the baseline run, except that the training loader now shuffles.
train_dataloader = mx.gluon.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# A new SGD instance, so no optimizer state (momentum, lr schedule position, etc.)
# carries over from the previous run.
optimizer = mx.optimizer.SGD(lr_scheduler=lr_schedule)
trained_net = train_resnet(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    description="w_shuffle",
)

Completed epoch 1
Completed epoch 2
Completed epoch 3
Completed epoch 4
Completed epoch 5
Completed epoch 6
Completed epoch 7
Completed epoch 8
Completed epoch 9
Completed epoch 10


## Update 2: Increase batch size

In [11]:
# Written as 128 * BATCH_SIZE_SCALE (128 is the batch size set in the Data section)
# instead of `batch_size * 4`, so running this cell more than once does not keep
# multiplying the batch size while the learning-rate scaling below stays fixed.
# Keep the 128 in sync with the Data section if that value is changed.
BATCH_SIZE_SCALE = 4
batch_size = 128 * BATCH_SIZE_SCALE
train_dataloader = mx.gluon.data.DataLoader(train_dataset, batch_size, shuffle=True)
test_dataloader = mx.gluon.data.DataLoader(test_dataset, batch_size)

# Base schedule, for reference:
#   lr_schedule = lambda iteration: min(iteration ** -0.5, iteration * 782 ** -1.5)
def new_lr_schedule(iteration):
    """
    Base schedule adapted to the larger batch: the base schedule is evaluated
    BATCH_SIZE_SCALE steps ahead per update and its learning rate is multiplied
    by the same factor.
    """
    return lr_schedule(iteration * BATCH_SIZE_SCALE) * BATCH_SIZE_SCALE

optimizer = mx.optimizer.SGD(lr_scheduler=new_lr_schedule)

trained_net = train_resnet(train_dataloader, test_dataloader, optimizer, description="inc_bs")

Completed epoch 1
Completed epoch 2
Completed epoch 3
Completed epoch 4
Completed epoch 5
Completed epoch 6
Completed epoch 7
Completed epoch 8
Completed epoch 9
Completed epoch 10


## Update 3: Normalize data

In [44]:
# ToTensor followed by a per-channel Normalize with the given mean and std.
# NOTE(review): these look like the commonly used CIFAR-10 channel statistics — confirm the source.
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
])

# Both splits are rebuilt so that they pick up the new transform.
train_dataset = mx.gluon.data.vision.CIFAR10(train=True).transform_first(transform_fn)
test_dataset = mx.gluon.data.vision.CIFAR10(train=False).transform_first(transform_fn)
train_dataloader = mx.gluon.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = mx.gluon.data.DataLoader(dataset=test_dataset, batch_size=batch_size)

# Same schedule as the previous (larger batch) run, with a new optimizer instance.
optimizer = mx.optimizer.SGD(lr_scheduler=new_lr_schedule)

trained_net = train_resnet(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    description="normalized_input",
)

Completed epoch 1
Completed epoch 2
Completed epoch 3
Completed epoch 4
Completed epoch 5
Completed epoch 6
Completed epoch 7
Completed epoch 8
Completed epoch 9
Completed epoch 10


In [12]:
# Shut down the TensorBoard process started at the top of the notebook
# (sends SIGTERM to its process group).
tb_server.stop()