In [None]:
!nvidia-smi

In [None]:
import os
# Expose only physical GPU index 1 to this process. This is set before MXNet is
# imported in the next cell.
# NOTE(review): later `gpu(0)` calls presumably resolve to that single visible
# device -- confirm on the target machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [None]:
# MXNet package
from mxnet import nd, init, cpu, gpu, gluon, autograd
from mxnet.gluon import nn
from mxnet.gluon.data import DataLoader, Sampler
from mxnet.gluon.data.vision import CIFAR10, transforms as T
from gluoncv.data import transforms as gcv_T
from gluoncv.model_zoo import cifar_resnet56_v1

# Normal package
import time
from tensorboardX import SummaryWriter
import numpy as np

# Custom package
import sys
sys.path.append("..")
from quantize import convert
from quantize.initialize import qparams_init

In [None]:
class Config:
    """Static hyper-parameters and bookkeeping settings read by the cells below."""

    # ---- Model ----
    num_class = 10                  # passed to evaluate() to size the per-class counters

    # ---- Train ----
    max_steps = 8000                # upper bound checked by the training loop
    train_batch_size = 64           # batch size of the training DataLoader
    val_batch_size = 128            # batch size of the evaluation DataLoader
    train_num_workers = 4           # worker count of the training DataLoader
    val_num_workers = 4             # worker count of the evaluation DataLoader
    lr = 1e-6                       # learning rate handed to the Adam trainer

    # ---- Record ----
    ckpt_dir = "./tmp/checkpoints"  # directory receiving *.params files
    main_tag = 'cifar_resnet56_v1_quantize'      # TensorBoard tag prefix / run-dir prefix
    ckpt_prefix = 'cifar_resnet56_v1_quantize'   # file-name prefix of saved parameters
    train_record_per_steps = 200    # log training loss/accuracy every N steps
    val_per_steps = 400             # evaluate every N steps
    # First step at which a "good model" checkpoint may be saved.
    # NOTE(review): this is larger than max_steps, so with these values the
    # spotter branch of the training loop never fires -- confirm that is intended.
    spotter_starts_at = 10000
    spotter_window_size = 10        # number of recent eval accuracies to beat
    patience = 20                   # length of the early-stop loss window (in evaluations)
    snapshot_per_steps = 400        # unconditional parameter snapshot every N steps

    # ---- Quantize ----
    # Step from which the training loop calls
    # net.quantize_input(enable=True, online=False), checked at evaluation steps.
    offline_at = 3000

In [None]:
# Create the checkpoint directory together with any missing parents (e.g. "./tmp").
# os.mkdir raises FileNotFoundError when the parent directory is absent;
# makedirs(..., exist_ok=True) handles that and keeps the cell safe to re-run.
os.makedirs(Config.ckpt_dir, exist_ok=True)

In [None]:
# TensorBoard writer: every run logs into its own directory, named after the
# main tag plus the local wall-clock time at which this cell was executed.
datetime_stamp = time.strftime('%Y%m%d_%H%M%S')  # no time tuple given -> current local time
writer = SummaryWriter(
    log_dir="tmp/runs/{}_{}".format(Config.main_tag, datetime_stamp),
)

In [None]:
# Quantize inputs
# Mapping from Gluon layer type to the conversion callable used for that layer;
# it is passed as `convert_fn` to convert.convert_model() in the next cell.
# NOTE(review): `input_width=4` / `weight_width=4` presumably are bit widths and
# a `None` entry presumably means "leave this layer type untouched" -- confirm
# against the `quantize.convert` module.
converter = {
    nn.Conv2D: convert.gen_conv2d_converter(quant_type="channel", fake_bn=True, input_width=4, weight_width=4),
    nn.Dense: convert.gen_dense_converter(quant_type="channel"),
    nn.Activation: None,
    nn.BatchNorm: convert.bypass_bn
}

In [None]:
# Build the model: start from the pretrained CIFAR ResNet-56 v1 of the GluonCV model zoo.
net = cifar_resnet56_v1(pretrained=True)
# Convert layers in place with the mapping defined above; the first two entries
# of `net.features` are excluded from conversion.
convert.convert_model(net, exclude=[net.features[0], net.features[1]], convert_fn=converter)
# Input quantization starts disabled; the training loop enables it (online=False)
# once `Config.offline_at` steps have been reached.
net.quantize_input(enable=False)
# NOTE(review): presumably initializes the quantization parameters added by the
# conversion -- confirm in `quantize.initialize`.
qparams_init(net)
# Move all parameters to the GPU context used by the training / evaluation cells.
net.collect_params().reset_ctx(gpu(0))

In [None]:
def evaluate(net, num_class, dataloader, ctx):
    """Run `net` over `dataloader` and compute loss and accuracy statistics.

    Parameters
    ----------
    net : callable
        Model invoked as ``net(X)``; predictions are the arg-max over axis 1.
    num_class : int
        Number of classes; sizes the per-class counters.
    dataloader : iterable
        Yields ``(X, y)`` batches whose elements support ``as_in_context``.
    ctx : mxnet Context
        Device every batch is moved to before the forward pass.

    Returns
    -------
    tuple of float
        ``(eval_loss, eval_acc, eval_acc_avg, elapsed_seconds)`` where
        ``eval_loss`` is the summed loss divided by the number of samples seen,
        ``eval_acc`` the fraction of correct predictions and ``eval_acc_avg``
        the mean of the per-class accuracies.

    Notes
    -----
    The loss is computed with the module-level ``loss_func``, which must be
    defined before the first call. All statistics are normalized by the number
    of samples actually yielded by `dataloader`, so the function does not
    depend on any particular global dataset.
    """
    t = time.time()
    correct_counter = np.zeros(num_class)
    label_counter = np.zeros(num_class)
    num_samples = 0
    eval_loss = 0.

    for X, y in dataloader:
        X = X.as_in_context(ctx)
        y = y.as_in_context(ctx)

        outputs = net(X)
        loss = loss_func(outputs, y)
        eval_loss += float(loss.sum().asscalar())

        # Copy predictions and labels to host memory once per batch and count
        # per class with vectorized bincount instead of a per-sample loop.
        pred = outputs.argmax(axis=1).asnumpy().astype('int64').reshape(-1)
        gt = y.asnumpy().astype('int64').reshape(-1)
        num_samples += gt.shape[0]
        label_counter += np.bincount(gt, minlength=num_class)
        correct_counter += np.bincount(gt[pred == gt], minlength=num_class)

    # Guard against an empty dataloader (all statistics are then 0).
    denom = max(num_samples, 1)
    eval_loss /= denom
    eval_acc = float(correct_counter.sum() / denom)
    # The epsilon keeps classes without any sample from dividing by zero;
    # such classes contribute an accuracy of 0 to the mean.
    eval_acc_avg = float((correct_counter / (label_counter + 1e-10)).mean())

    return eval_loss, eval_acc, eval_acc_avg, time.time() - t

In [None]:
def _make_normalize():
    """Return a fresh Normalize transform with the per-channel mean/std used by both pipelines."""
    return T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])


# Training pipeline: padded random crop and random horizontal flip as augmentation,
# followed by tensor conversion and normalization.
train_transformer = T.Compose([
    gcv_T.RandomCrop(32, pad=4),
    T.RandomFlipLeftRight(),
    T.ToTensor(),
    _make_normalize(),
])

# Evaluation pipeline: tensor conversion and normalization only.
eval_transformer = T.Compose([
    T.ToTensor(),
    _make_normalize(),
])

In [None]:
# Training data: the augmentation pipeline is applied to the image only
# (transform_first leaves the label untouched).
train_dataset = CIFAR10(train=True).transform_first(train_transformer)
# shuffle=True: without it the loader yields the samples in the same order every
# epoch and, combined with last_batch='discard', always drops the same trailing
# samples. Shuffling gives each epoch a different batch composition.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=Config.train_batch_size,
                          shuffle=True,
                          num_workers=Config.train_num_workers,
                          last_batch='discard')

# Evaluation data: deterministic order, and the final partial batch is kept so
# that every test sample is evaluated.
test_dataset = CIFAR10(train=False).transform_first(eval_transformer)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=Config.val_batch_size,
                         shuffle=False,
                         num_workers=Config.val_num_workers,
                         last_batch='keep')

In [None]:
# Summary: print the dataset sizes and how the step-based settings in Config
# translate into samples / epochs.
train_size = len(train_dataset)
val_size = len(test_dataset)
print(f'trainset size => {train_size}')
print(f'valset size => {val_size}')
# Use the number of batches the training loader actually yields. It discards the
# last partial batch, so `train_size / batch_size` would report a fractional
# step count that never occurs.
steps_per_epoch = len(train_loader)
print(f'{steps_per_epoch} steps for per epoch (BATCH_SIZE={Config.train_batch_size})')
print("record per {} steps ({} samples, {} times per epoch)".format(
    Config.train_record_per_steps,
    Config.train_record_per_steps * Config.train_batch_size,
    steps_per_epoch / Config.train_record_per_steps))
print("evaluate per {} steps ({} times per epoch)".format(
    Config.val_per_steps,
    steps_per_epoch / Config.val_per_steps))
print("spotter start at {} steps ({} epoches)".format(
    Config.spotter_starts_at,
    Config.spotter_starts_at / steps_per_epoch))
print("size of spotter window is {} ({} steps)".format(
    Config.spotter_window_size,
    Config.spotter_window_size * Config.val_per_steps))
print("max patience: {} ({} steps; {} samples; {} epoches)".format(
    Config.patience,
    Config.patience * Config.val_per_steps,
    Config.patience * Config.val_per_steps * Config.train_batch_size,
    Config.patience * Config.val_per_steps / steps_per_epoch))
print("snapshot per {} steps ({} times per epoch)".format(
    Config.snapshot_per_steps,
    steps_per_epoch / Config.snapshot_per_steps))

In [None]:
# Mutable state shared by the evaluation and training cells below.
global_steps = 0            # number of optimizer steps performed so far
quantize_offline = False    # set to True once the training loop has switched input quantization on
# Sliding windows, pre-filled with zeros: recent evaluation accuracies (spotter)
# and recent evaluation losses (early stopping).
good_acc_window = [0. for _ in range(Config.spotter_window_size)]
estop_loss_window = [0. for _ in range(Config.patience)]

In [None]:
# Loss used by both training and evaluate(), and the Adam trainer over all
# parameters of the converted network.
loss_func = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(
    params=net.collect_params(),
    optimizer='adam',
    optimizer_params={'learning_rate': Config.lr},
)

In [None]:
# Baseline evaluation before any training step; logged at the current global_steps.
eval_loss, eval_acc, eval_acc_avg, cost_time = evaluate(
    net, Config.num_class, test_loader, ctx=gpu(0))

baseline_acc = {'val': eval_acc, 'val_avg': eval_acc_avg}
writer.add_scalars(f'{Config.main_tag}/Loss', {'val': eval_loss}, global_steps)
writer.add_scalars(f'{Config.main_tag}/Acc', baseline_acc, global_steps)
print(f"Evaluate cost time: {cost_time}")

In [None]:
# Training loop: runs until Config.max_steps is reached or early stopping fires.
# Every step trains on one batch; periodically it logs training statistics,
# evaluates on the test loader, saves "good" checkpoints and unconditional snapshots.
size_per_record = Config.train_record_per_steps * Config.train_batch_size  # samples per logging interval
flag_early_stop = False
train_loss = 0.
train_num_correct = 0
prune_counter = 0   # not used in this cell; kept in case other cells reference it
t = time.time()     # start of the current speed-measurement interval
while global_steps < Config.max_steps and not flag_early_stop:
    for X, y in train_loader:
        global_steps += 1
        # Move data to gpu
        X = X.as_in_context(gpu(0))
        y = y.as_in_context(gpu(0))
        # Forward & Backward
        with autograd.record():
            outputs = net(X)
            loss = loss_func(outputs, y)
        # NOTE(review): presumably refreshes the moving-average statistics kept by
        # the quantized layers -- confirm in the `quantize` package.
        net.update_ema()
        loss.backward()
        # ignore_stale_grad=True is required when BatchNorm is bypassed (see the
        # converter mapping), per the original author's note.
        trainer.step(Config.train_batch_size, ignore_stale_grad=True)

        train_loss += loss.sum().asscalar()
        pred = outputs.argmax(axis=1)
        train_num_correct += (pred == y.astype('float32')).sum().asscalar()

        # Record training info (averaged over the samples of the interval, then reset)
        if global_steps % Config.train_record_per_steps == 0:
            writer.add_scalars(f'{Config.main_tag}/Loss', {'train': train_loss/size_per_record}, global_steps)
            writer.add_scalars(f'{Config.main_tag}/Acc', {'train': train_num_correct/size_per_record}, global_steps)
            train_loss = 0.
            train_num_correct = 0

        # Evaluate
        if global_steps % Config.val_per_steps == 0:
            # Quantize: switch input quantization on (online=False) once, at the
            # first evaluation step at or after Config.offline_at.
            if not quantize_offline and global_steps >= Config.offline_at:
                print("Quantize offline...")
                net.quantize_input(enable=True, online=False)
                quantize_offline = True

            # Evaluate
            eval_loss, eval_acc, eval_acc_avg, __ = evaluate(net, Config.num_class, test_loader, ctx=gpu(0))
            # Steps per second since the previous evaluation finished; the elapsed
            # time includes the evaluation that just ran.
            writer.add_scalar(f'{Config.main_tag}/Speed', Config.val_per_steps / (time.time() - t), global_steps)
            writer.add_scalars(f'{Config.main_tag}/Loss', {'val': eval_loss}, global_steps)
            writer.add_scalars(f'{Config.main_tag}/Acc', {
                'val': eval_acc,
                'val_avg': eval_acc_avg
            }, global_steps)

            # Spotter: drop the oldest accuracy, save a checkpoint if the new one
            # beats every accuracy left in the window, then append it.
            good_acc_window.pop(0)
            if global_steps >= Config.spotter_starts_at and eval_acc > max(good_acc_window):
                print( "catch a good model with acc {:.6f} at {} step".format(eval_acc, global_steps) )
                writer.add_text(Config.main_tag, "catch a good model with acc {:.6f}".format(eval_acc), global_steps)
                net.save_parameters("{}/{}-{:06d}.params".format(Config.ckpt_dir, Config.ckpt_prefix, global_steps))
            good_acc_window.append(eval_acc)

            # Early stop: slide the loss window; once it only holds real
            # evaluations, stop when the lowest loss is the oldest entry, i.e. no
            # later evaluation in the window improved on it.
            estop_loss_window.pop(0)
            estop_loss_window.append(eval_loss)
            if global_steps > Config.val_per_steps * len(estop_loss_window):
                min_index = estop_loss_window.index( min(estop_loss_window) )
                writer.add_scalar(f'{Config.main_tag}/val/Patience', min_index, global_steps)
                if min_index == 0:
                    flag_early_stop = True
                    print("early stop at {} steps".format(global_steps))
                    break

            t = time.time()

        # Snapshot
        if global_steps % Config.snapshot_per_steps == 0:
            net.save_parameters("{}/{}-{:06d}.params".format(Config.ckpt_dir, Config.ckpt_prefix, global_steps))

        # Stop exactly at max_steps. The outer `while` is only re-checked between
        # epochs, so without this break training would continue to the end of the
        # current epoch and overshoot the configured limit.
        if global_steps >= Config.max_steps:
            break

In [None]:
# End the session once the cells above have finished.
# NOTE(review): inside a notebook this presumably shuts the kernel down and so
# discards in-memory state -- confirm that is intended.
exit()