### Prepare
#### List GPU

In [None]:
# Show the GPUs on this machine, to pick a device for CUDA_VISIBLE_DEVICES in the next cell.
!nvidia-smi

#### Select GPUs to run

In [None]:
import os

# Expose only physical GPU 2 to this process; the rest of the notebook then
# addresses it as gpu(0). Set before MXNet touches CUDA (i.e. before the imports below).
os.environ.update(CUDA_VISIBLE_DEVICES='2')

#### Import packages and set up configuration

In [None]:
# MXNet package
from mxnet import nd, init, cpu, gpu, gluon, autograd
from mxnet.gluon import nn
from mxnet.gluon.data import DataLoader
from mxnet.gluon.data.vision import CIFAR100, transforms as T
from mxnet.gluon.model_zoo.vision import resnet50_v1, mobilenet1_0, mobilenet_v2_1_0

# Normal package
import time
import logging
from tensorboardX import SummaryWriter

# Custom package
import sys
sys.path.append("..")
from utils import Timer, ModelConfig, QTrainConfig
from quantize.convert import convert_model
from quantize.initialize import qparams_init as qinit

In [None]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - [%(levelname)s]%(name)s: %(message)s')

In [None]:
# Global step at which the training loop switches the network to offline input
# quantization; the best-model spotter only begins saving 2000 steps after that.
QUANT_OFFLINE_AFTER = 10000
# Run name, reused as the TensorBoard tag and the checkpoint filename prefix.
MAIN_TAG = "cifar100_resnet50_v1_quant_relu6"

# 20 output classes -- presumably the CIFAR-100 coarse labels, since CIFAR100 is
# constructed below without `fine_label` (TODO confirm).
mconfig = ModelConfig(num_class=20)

tconfig = QTrainConfig(
    tb_main_tag=MAIN_TAG,
    checkpoint_prefix=MAIN_TAG,
    # Weights loaded into the network before it is converted to the quantized version.
    param_file="./models/cifar100_resnet50-007000.params",
    train_batch_size=16,
    val_batch_size=16,
    learning_rate=1e-6,
    quant_offline_after=QUANT_OFFLINE_AFTER,
    spotter_starts_at=QUANT_OFFLINE_AFTER + 2000,
)

#### Create a TensorBoard writer
The log directory is named after the main tag and the current date and time.

In [None]:
# Local-time stamp so every run logs to its own directory: runs/<tag>_<YYYYmmdd_HHMMSS>.
datetime_stamp = time.strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(
    log_dir="runs/{}_{}".format(tconfig.tb_main_tag, datetime_stamp)
)

#### Construct the model
1. Define the model
2. Load the trained parameters
3. Convert it to the quantized version
4. Initialize the new parameters introduced by the conversion

In [None]:
# Alternative backbone (disabled): MobileNet 1.0 with the same classifier head.
# net = mobilenet1_0(pretrained=False)
# net.output = nn.Dense(mconfig.num_class, weight_initializer='normal', bias_initializer='zeros')

# 1. Define the model: ResNet-50 v1 with its output layer replaced by a Dense
#    layer with mconfig.num_class units.
net = resnet50_v1(pretrained=False)
net.output = nn.Dense(mconfig.num_class, weight_initializer='normal', bias_initializer='zeros')

# 2. Load parameters. This runs before conversion, so the file presumably holds
#    the un-quantized network's parameters (confirm against how it was saved).
net.load_parameters(tconfig.param_file)
# 3. Convert to the quantized version, leaving the first conv layer as-is.
convert_model(net, exclude=[net.features[0]])  # exclude the first conv
# 4. Initialize the new parameters introduced by the conversion.
qinit(net)
# Move every parameter to the single visible GPU; `_ =` only discards the return value.
_ = net.collect_params().reset_ctx(gpu(0))

#### Define an evaluation function

In [None]:
def evaluate(net, num_class, dataloader, ctx, loss_fn=None):
    """Evaluate `net` over every batch of `dataloader`.

    Args:
        net: model to evaluate; it is called outside ``autograd.record()``.
        num_class: number of classes, used for the per-class accuracy.
        dataloader: iterable of ``(X, y)`` batches with integer class labels.
        ctx: MXNet context each batch is moved to before the forward pass.
        loss_fn: loss block applied as ``loss_fn(outputs, y)``. Defaults to a new
            ``gluon.loss.SoftmaxCrossEntropyLoss()``, so the function does not
            depend on a notebook-global loss being defined first.

    Returns:
        ``(eval_loss, eval_acc, eval_acc_avg, seconds)``:
        - ``eval_loss``: mean per-sample loss.
        - ``eval_acc``: overall accuracy.
        - ``eval_acc_avg``: unweighted mean of per-class accuracies; a class with
          no samples contributes 0.
        - ``seconds``: wall-clock duration.
        All metrics are 0 when the loader yields no samples.
    """
    if loss_fn is None:
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    t = time.time()
    correct_counter = nd.zeros(num_class, ctx=ctx)
    label_counter = nd.zeros(num_class, ctx=ctx)
    test_num_correct = 0
    eval_loss = 0.
    num_samples = 0

    for X, y in dataloader:
        X = X.as_in_context(ctx)
        y = y.as_in_context(ctx)

        outputs = net(X)
        eval_loss += loss_fn(outputs, y).sum().asscalar()
        pred = outputs.argmax(axis=1)   # argmax returns float32 class indices
        y_float = y.astype('float32')
        hit = (pred == y_float)         # 1. where correct, 0. otherwise
        test_num_correct += hit.sum().asscalar()
        num_samples += y.shape[0]

        # Vectorised per-class tallies, computed on `ctx` without a per-sample
        # Python loop: summing one-hot label rows counts each class's samples,
        # masking those rows with `hit` counts each class's correct predictions.
        y_onehot = nd.one_hot(y_float, num_class)
        label_counter += y_onehot.sum(axis=0)
        correct_counter += nd.broadcast_mul(y_onehot, hit.expand_dims(axis=1)).sum(axis=0)

    # Normalise by the samples this loader actually produced (not a global
    # dataset length); max(..., 1) keeps an empty loader from dividing by zero.
    denom = max(num_samples, 1)
    eval_loss /= denom
    eval_acc = test_num_correct / denom
    # Epsilon avoids 0/0 for classes that never appeared.
    eval_acc_avg = (correct_counter / (label_counter+1e-10)).mean().asscalar()

    return eval_loss, eval_acc, eval_acc_avg, time.time()-t

#### Dataset & Dataloader

In [None]:
# No augmentation yet: both splits use the same preprocessing recipe, but each
# gets its own Compose instance so the training pipeline can diverge later.
train_transformer, eval_transformer = (
    T.Compose([
        T.Resize(224),                              # resize to the 224-pixel network input
        T.ToTensor(),                               # HWC uint8 [0, 255] -> CHW float32 [0, 1]
        T.Normalize([.5, .5, .5], [.5, .5, .5]),    # (x - 0.5) / 0.5 per channel -> [-1, 1]
    ])
    for _split in ('train', 'eval')
)

In [None]:
# CIFAR100 calls `transform(data, label)` per sample: only the image is
# preprocessed, the label passes through untouched.
train_dataset = CIFAR100(
    train=True,
    transform=lambda img, lbl: (train_transformer(img), lbl),
)
# Training drops the trailing partial batch so every step sees a full batch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=tconfig.train_batch_size,
    shuffle=True,
    num_workers=tconfig.train_num_prefetch_workers,
    last_batch='discard',
)
test_dataset = CIFAR100(
    train=False,
    transform=lambda img, lbl: (eval_transformer(img), lbl),
)
# Evaluation keeps the trailing partial batch so every test sample is scored.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=tconfig.val_batch_size,
    shuffle=False,
    num_workers=tconfig.val_num_prefetch_workers,
    last_batch='keep',
)

tconfig.summary(train_dataset, test_dataset)

## Train

In [None]:
# Mutable training state shared by the loop cell; re-running this cell resets
# the step counter and both metric windows (not the network weights).
global_steps = 0
# Recent validation accuracies; the loop's spotter saves a checkpoint when a new
# accuracy beats every value still in this window.
good_acc_window = [0.]*tconfig.spotter_window_size
# Recent validation losses for early stopping (stop when the oldest entry is the
# minimum). Placeholders are +inf so they can never be that minimum. With 0.
# placeholders, min() landed on index 0 and training stopped at the first
# post-quantization evaluation, before any real losses had filled the window.
estop_loss_window = [float('inf')]*tconfig.patience
# Set to True by the loop once net.quantize_input_offline() has been called.
quantize_input_offline = False

In [None]:
# Softmax cross-entropy on integer class labels.
loss_func = gluon.loss.SoftmaxCrossEntropyLoss()
# Adam over every parameter of the converted network, at the configured learning rate.
trainer = gluon.Trainer(
    params=net.collect_params(),
    optimizer='adam',
    optimizer_params={'learning_rate': tconfig.learning_rate},
)

In [None]:
# Main loop: one optimizer step per batch until tconfig.max_steps or early stop.
# Periodically it logs train metrics, evaluates, checkpoints, and (once, at
# tconfig.quant_offline_after) switches the network to offline input quantization.
t = Timer()
flag_early_stop = False
# Running sums for the current logging window; reset after every record.
train_loss = 0.
train_num_correct = 0
train_num_samples = 0
t.start()
while global_steps < tconfig.max_steps and not flag_early_stop:
    for X, y in train_loader:
        # Enforce max_steps inside an epoch too, not only between epochs.
        if global_steps >= tconfig.max_steps:
            break
        # Move data to gpu
        X = X.as_in_context(gpu(0))
        y = y.as_in_context(gpu(0))
        # Forward & Backward
        with autograd.record():
            outputs = net(X)
            loss = loss_func(outputs, y)
        net.update_ema()
        loss.backward()
        trainer.step(tconfig.train_batch_size, ignore_stale_grad=True)   # ignore the original batchnorm parameters
        train_loss += loss.sum().asscalar()
        pred = outputs.argmax(axis=1)
        train_num_correct += (pred == y.astype('float32')).sum().asscalar()
        train_num_samples += X.shape[0]

        # Record training info
        if global_steps and global_steps % tconfig.train_record_per_steps == 0:
            t.stop()
            # Per-sample means over the samples actually accumulated in this window
            # (the first window also contains step 0). This puts the train loss on
            # the same scale as the per-sample validation loss on the same chart.
            writer.add_scalar('{}/train/Speed'.format(tconfig.tb_main_tag), 1.*tconfig.train_record_per_steps/t.pop(), global_steps)
            writer.add_scalars('{}/Loss'.format(tconfig.tb_main_tag), {'train': train_loss/train_num_samples}, global_steps)
            writer.add_scalars('{}/Acc'.format(tconfig.tb_main_tag), {'train': train_num_correct/train_num_samples}, global_steps)
            train_loss = 0.
            train_num_correct = 0
            train_num_samples = 0
            t.start()
        # Evaluate
        if global_steps and global_steps % tconfig.val_per_steps == 0:
            t.stop()
            eval_loss, eval_acc, eval_acc_avg, __ = evaluate(net, mconfig.num_class, test_loader, ctx=gpu(0))
            writer.add_scalars('{}/Loss'.format(tconfig.tb_main_tag), {'val': eval_loss}, global_steps)
            writer.add_scalars('{}/Acc'.format(tconfig.tb_main_tag), {
                'val': eval_acc,
                'val_avg': eval_acc_avg
            }, global_steps)

            # Spotter & Patience -- only active after offline input quantization.
            if quantize_input_offline:
                # Spotter: save the model when it beats every accuracy in the window.
                good_acc_window.pop(0)
                if global_steps >= tconfig.spotter_starts_at and eval_acc > max(good_acc_window):
                    print( "catch a good model with acc {:.6f} at {} step".format(eval_acc, global_steps) )
                    writer.add_text(tconfig.tb_main_tag, "catch a good model with acc {:.6f}".format(eval_acc), global_steps)
                    net.save_parameters("{}/{}-{:06d}.params".format(tconfig.checkpoint_dir, tconfig.checkpoint_prefix, global_steps))
                good_acc_window.append(eval_acc)

                # Early stop: stop when the oldest loss in the window is the lowest.
                estop_loss_window.pop(0)
                estop_loss_window.append(eval_loss)
                # The window only receives losses after quantization starts, so
                # measure "window full" from that step. Then the initial placeholder
                # entries are never compared against real losses.
                if global_steps - tconfig.quant_offline_after > tconfig.val_per_steps*len(estop_loss_window):
                    min_index = estop_loss_window.index( min(estop_loss_window) )
                    writer.add_scalar('{}/val/Patience'.format(tconfig.tb_main_tag), min_index, global_steps)
                    if min_index == 0:
                        flag_early_stop = True
                        print("early stop at {} steps".format(global_steps))
                        break

            t.start()

        # Snapshot
        if global_steps and global_steps % tconfig.snapshot_per_steps == 0:
            net.save_parameters("{}/{}-{:06d}.params".format(tconfig.checkpoint_dir, tconfig.checkpoint_prefix, global_steps))

        # Quantize input offline (exactly once, at the configured step)
        if global_steps == tconfig.quant_offline_after:
            quantize_input_offline = True
            net.quantize_input_offline()

        # Next step
        global_steps += 1

In [None]:
# Flush pending events and close the TensorBoard writer.
writer.close()
# NOTE(review): inside a Jupyter/IPython kernel, exit() asks the kernel to shut
# down, so all notebook state is lost. This is presumably intended to release
# the GPU after training; confirm before running the notebook interactively.
exit(0)