This notebook gives a detailed description of MetaQuant through code and mathematical explanation.

## Motivation
Training-based quantization aims to minimize the following training loss:

$$\min_{\mathbf{W}} \ell = \text{Loss}\big(f(Q(\mathbf{W}); \mathbf{x})\big)$$

where $Q(\cdot)$ quantizes the full-precision weights $\mathbf{W}$ into quantized values $\hat{\mathbf{W}}$ and $f$ is the network evaluated on input $\mathbf{x}$. Because $Q(\cdot)$ is a step function, its derivative is zero almost everywhere and undefined at the quantization thresholds, so the gradient of $\ell$ w.r.t. $\mathbf{W}$ cannot be obtained by ordinary backpropagation.
To enable stable quantization training, the Straight-Through Estimator (STE) redefines $\partial Q(r)/\partial r$ as:

$$
\frac{\partial Q(r)}{\partial r} =
\begin{cases}
1 & \text{if } |r| \leq 1, \\
0 & \text{otherwise}.
\end{cases}
$$
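A minimal NumPy sketch of the STE rule, assuming an illustrative binary quantizer $Q(r) = \operatorname{sign}(r)$ with $Q(0) = +1$. `quantize_sign` and `ste_backward` are hypothetical helper names, not functions from this repository. The forward pass quantizes, while the backward pass lets the upstream gradient through unchanged where $|r| \leq 1$ and zeroes it elsewhere.

```python
import numpy as np

def quantize_sign(r):
    """Illustrative binary quantizer Q(r) = sign(r), mapping 0 to +1."""
    return np.where(r >= 0, 1.0, -1.0)

def ste_backward(r, grad_q):
    """STE: approximate dl/dr by dl/dQ(r) where |r| <= 1, and by 0 elsewhere."""
    return np.where(np.abs(r) <= 1, grad_q, 0.0)

r = np.array([-1.5, -0.3, 0.0, 0.7, 2.0])        # full-precision values
grad_q = np.array([0.2, -0.4, 0.1, 0.5, -0.3])   # upstream gradient dl/dQ(r)

q = quantize_sign(r)              # forward: values in {-1, +1}
grad_r = ste_backward(r, grad_q)  # backward: zeroed at r = -1.5 and r = 2.0
print(q, grad_r)
```

The true derivative of `quantize_sign` is zero wherever it is defined, so without the STE no gradient would reach `r` at all.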

However, STE inevitably introduces **gradient mismatch**: the gradients with respect to the weights are computed not at the full-precision weights themselves but at their quantized values. Although STE provides an end-to-end training method under discrete constraints, few works have investigated how to obtain better gradients for quantization training.
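A toy illustration of gradient mismatch, assuming a hypothetical one-weight model with loss $\ell = \frac{1}{2}(Q(w)\,x - y)^2$, $Q = \operatorname{sign}$, and the STE defined above. The STE gradient is evaluated at $Q(w)$, so two different full-precision weights with the same quantized value receive exactly the same gradient.

```python
import numpy as np

def quantize_sign(w):
    """Illustrative binary quantizer Q(w) = sign(w), mapping 0 to +1."""
    return np.where(w >= 0, 1.0, -1.0)

def ste_grad(w, x, y):
    """dl/dw under the STE for the toy loss l = 0.5 * (Q(w) * x - y) ** 2."""
    w_q = quantize_sign(w)
    grad_wq = (w_q * x - y) * x                    # dl/dQ(w), evaluated at the quantized weight
    return np.where(np.abs(w) <= 1, grad_wq, 0.0)  # STE pass-through inside [-1, 1]

x, y = 2.0, 1.0
g_small = ste_grad(0.1, x, y)  # w = 0.1 quantizes to +1
g_large = ste_grad(0.9, x, y)  # w = 0.9 also quantizes to +1
print(g_small, g_large)        # both are (1 * 2 - 1) * 2 = 2.0
```

The gradient depends on `w` only through `Q(w)` and the clipping mask, which is exactly the mismatch described above.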

To overcome gradient mismatch and explore better gradients for training-based methods, we propose to learn $\frac{\partial Q(\mathbf{W})}{\partial \mathbf{W}}$ with a neural network ($\mathcal{M}$) during quantization training. This network is called the **meta quantizer** and is trained jointly with the base quantized model. The resulting process is called **Meta** **Quant**ization (MetaQuant).

Specifically, in each backward pass, $\mathcal{M}$ takes $\frac{\partial \ell}{\partial Q(\mathbf{W})}$ and $\mathbf{W}$ as inputs in a coordinate-wise manner, and its output is assigned to $\frac{\partial \ell}{\partial \mathbf{W}}$ for the weight update with common optimizers such as SGD and Adam. In the forward pass, inference uses the quantized version of the updated weights, which produces the final outputs that are compared with the ground-truth labels for the backward computation. Throughout this process, gradient propagation from the quantized weights to the full-precision weights is handled by $\mathcal{M}$, which avoids both the non-differentiability and the gradient mismatch. In addition, the gradients generated by the meta quantizer are loss-aware, which contributes to better performance of quantization training.
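A minimal NumPy sketch of the coordinate-wise idea, under stated assumptions: the meta quantizer here is a hypothetical two-layer MLP with a `tanh` hidden layer (`init_meta_params` and `meta_quantizer` are illustrative names, not the repository's `MetaDesignedMultiFC`). Each weight coordinate is processed independently: its pair $(\partial\ell/\partial Q(\mathbf{W}), \mathbf{W})$ is mapped to a scalar, which is then used as $\partial\ell/\partial\mathbf{W}$ in a plain SGD step.

```python
import numpy as np

def init_meta_params(hidden_size, rng):
    """Hypothetical 2 -> hidden -> 1 MLP shared by all weight coordinates."""
    return {
        "W1": rng.normal(scale=0.1, size=(2, hidden_size)),
        "b1": np.zeros(hidden_size),
        "W2": rng.normal(scale=0.1, size=(hidden_size, 1)),
        "b2": np.zeros(1),
    }

def meta_quantizer(grad_q, weight, params):
    """Map each (dl/dQ(W), W) pair to a meta gradient with the same shape as W."""
    x = np.stack([grad_q.ravel(), weight.ravel()], axis=1)  # (N, 2): one row per coordinate
    h = np.tanh(x @ params["W1"] + params["b1"])             # (N, hidden)
    out = h @ params["W2"] + params["b2"]                    # (N, 1)
    return out.reshape(weight.shape)

rng = np.random.default_rng(0)
params = init_meta_params(hidden_size=100, rng=rng)
W = rng.normal(size=(16, 3, 3, 3))             # full-precision conv weights
grad_q = rng.normal(size=W.shape)              # dl/dQ(W) from the backward pass
meta_grad = meta_quantizer(grad_q, W, params)  # used in place of dl/dW
lr = 0.1
W_new = W - lr * meta_grad                     # plain SGD step with the meta gradient
```

Because the same small network is applied to every coordinate independently, its parameter count does not depend on the size of $\mathbf{W}$, and the output for any coordinate depends only on that coordinate's gradient and weight.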

## Workflow of MetaQuant

<img src="figs/MetaQuant.png">

In [None]:
"""
This block imports packages and initializes key parameters
"""
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import shutil
import pickle
import time
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim

from utils.dataset import get_dataloader
from meta_utils.meta_network import MetaFC, MetaLSTMFC, MetaDesignedMultiFC
from meta_utils.SGD import SGD
from meta_utils.adam import Adam
from meta_utils.helpers import meta_gradient_generation, update_parameters
from utils.recorder import Recorder
from utils.miscellaneous import AverageMeter, accuracy, progress_bar
from utils.miscellaneous import get_layer
from utils.quantize import test

from models_CIFAR.quantized_meta_resnet import resnet20_cifar

# ------------------------------------------
use_cuda = torch.cuda.is_available()
model_name = 'ResNet20'
dataset_name = 'CIFAR10'
meta_method = 'MultiFC' # ['LSTMFC', 'MultiFC', 'FC-Grad']
MAX_EPOCH = 100
optimizer_type = 'SGD' # ['SGD', 'SGD-M', 'adam']
hidden_size = 100
num_lstm = 2
num_fc = 3
lr_adjust = '30'
batch_size = 128
bitW = 1
quantized_type = 'dorefa' # ['dorefa', 'BWN']
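save_root = './Results/%s-%s' % (model_name, dataset_name)

# The cells below were ported from a command-line script and still read a few
# settings from `args`. NOTE: these values are placeholders, not the script's
# documented defaults; set them for your own run.
from argparse import Namespace
args = Namespace(
    meta_nonlinear=None,  # non-linearity option forwarded to the meta network
    num_fc=num_fc,        # number of FC layers in the 'MultiFC' meta network
    weight_decay=0.0,     # weight decay of the meta optimizer
    init_lr=1e-3,         # initial learning rate of the base model
    exp_spec='',          # optional suffix appended to the record directory name
)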
# ------------------------------------------
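The `quantized_type` option selects between `dorefa` and `BWN`, with `bitW` giving the weight bit-width. The sketch below writes out the published forward-pass forms of these weight quantizers (DoReFa-Net, and the Binary Weight Network from XNOR-Net). The function names are hypothetical, and the repository's own `utils.quantize` code may differ in details such as tie handling or where the scaling factor is computed.

```python
import numpy as np

def quantize_k(r, k):
    """Uniformly quantize r in [0, 1] to k bits (np.round rounds halves to even)."""
    n = 2 ** k - 1
    return np.round(r * n) / n

def dorefa_weight(w, k):
    """DoReFa-Net weight quantizer, forward pass only."""
    if k == 1:
        # 1-bit case: sign(w) scaled by the layer-wise mean absolute weight
        return np.where(w >= 0, 1.0, -1.0) * np.mean(np.abs(w))
    t = np.tanh(w)
    r = t / (2 * np.max(np.abs(t))) + 0.5  # map into [0, 1]
    return 2 * quantize_k(r, k) - 1        # map back to [-1, 1]

def bwn_weight(w):
    """Binary Weight Network: alpha * sign(w), with one alpha per output filter (axis 0)."""
    axes = tuple(range(1, w.ndim))
    alpha = np.mean(np.abs(w), axis=axes, keepdims=True)
    return alpha * np.where(w >= 0, 1.0, -1.0)

w = np.array([-2.0, -0.5, 0.5, 2.0])
print(dorefa_weight(w, 1))  # [-1.25, -1.25, 1.25, 1.25]
print(dorefa_weight(w, 2))  # [-1, -1/3, 1/3, 1]
```

With `bitW = 1`, as in the configuration above, both quantizers reduce to a scaled sign; they differ in whether the scale is shared by the whole layer or computed per filter.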

In [None]:
"""
This block initializes the network and loads the dataset
"""

import utils.global_var as gVar
gVar.meta_count = 0

###################
# Initial Network #
###################
net = resnet20_cifar(bitW=bitW)
pretrain_path = '%s/%s-%s-pretrain.pth' % (save_root, model_name, dataset_name)
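# map_location='cpu' lets a GPU-saved checkpoint load on CPU-only machines;
# net.cuda() below moves the model back to the GPU when one is available
net.load_state_dict(torch.load(pretrain_path, map_location='cpu'), strict=False)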

# Get layer name list
layer_name_list = net.layer_name_list
# Check that every required layer has been initialized as a meta layer
assert (len(layer_name_list) == gVar.meta_count)
print('Layer name list completed.')

if use_cuda:
    net.cuda()
    
################
# Load Dataset #
################
train_loader = get_dataloader(dataset_name, 'train', batch_size)
test_loader = get_dataloader(dataset_name, 'test', 100)

In [None]:
"""
This block initializes the meta network, the optimizers and the recorder
"""
########################
# Initial Meta Network #
########################
if meta_method == 'LSTMFC':
    meta_net = MetaLSTMFC(hidden_size=hidden_size)
    SummaryPath = '%s/runs-Quant/Meta-%s-Nonlinear-%s-' \
                  'hidden-size-%d-nlstm-1-%s-%s-%dbits-lr-%s' \
                  % (save_root, meta_method, args.meta_nonlinear, hidden_size,
                     quantized_type, optimizer_type, bitW, lr_adjust)
elif meta_method in ['FC-Grad']:
    meta_net = MetaFC(hidden_size=hidden_size, use_nonlinear=args.meta_nonlinear)
    SummaryPath = '%s/runs-Quant/Meta-%s-Nonlinear-%s-' \
                  'hidden-size-%d-%s-%s-%dbits-lr-%s' \
                  % (save_root, meta_method, args.meta_nonlinear, hidden_size,
                     quantized_type, optimizer_type, bitW, lr_adjust)
elif meta_method == 'MultiFC':
    meta_net = MetaDesignedMultiFC(hidden_size=hidden_size,
                                   num_layers=num_fc,
                                   use_nonlinear=args.meta_nonlinear)
    SummaryPath = '%s/runs-Quant/Meta-%s-Nonlinear-%s-' \
                  'hidden-size-%d-nfc-%d-%s-%s-%dbits-lr-%s' \
                  % (save_root, meta_method, args.meta_nonlinear, hidden_size, num_fc,
                     quantized_type, optimizer_type, bitW, lr_adjust)
else:
    raise NotImplementedError

print(meta_net)

if use_cuda:
    meta_net.cuda()

meta_optimizer = optim.Adam(meta_net.parameters(), lr=1e-3, weight_decay=args.weight_decay)
    
#####################
# Initial Optimizee #
#####################
    
# Optimizer for the base network, used only to zero gradients and to compute refined gradients
if optimizer_type == 'SGD-M':
    optimizee = SGD(net.parameters(), lr=args.init_lr,
                    momentum=0.9, weight_decay=5e-4)
elif optimizer_type == 'SGD':
    optimizee = SGD(net.parameters(), lr=args.init_lr)
elif optimizer_type in ['adam', 'Adam']:
    optimizee = Adam(net.parameters(), lr=args.init_lr,
                     weight_decay=5e-4)
else:
    raise NotImplementedError
    
####################
# Initial Recorder #
####################
if args.exp_spec != '':
    SummaryPath += ('-' + args.exp_spec)

print('Save to %s' %SummaryPath)

if os.path.exists(SummaryPath):
    # Pause for confirmation before deleting an existing record directory
    input('Record %s exists; press Enter to remove it.' % SummaryPath)
    shutil.rmtree(SummaryPath)
os.makedirs(SummaryPath)

recorder = Recorder(SummaryPath=SummaryPath, dataset_name=dataset_name)

To train the meta quantizer, MetaQuant incorporates the meta gradient produced by the meta quantizer into the inference (forward) process, so that the loss depends on the meta quantizer's parameters:

<img src="figs/MetaQuant-Forward.png">

Therefore, the forward pass of the base network is modified to incorporate the meta gradients, which are passed in through `meta_grad_dict`.
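A scalar sketch of why embedding the meta gradient gives the meta quantizer a training signal. It assumes that a meta layer computes its forward weight as $Q(\mathbf{W} - \eta\,\mathcal{M}(\partial\ell/\partial Q(\mathbf{W}), \mathbf{W}))$, with $\eta$ the learning rate passed as `lr`; this matches the `lr=` argument in the training cell below, but the exact form inside `quantized_meta_resnet` is not shown here. The meta quantizer is a hypothetical linear map $\mathcal{M}(g, w) = a\,g + b\,w$, the loss is the toy $\frac{1}{2}(q\,x - y)^2$, and the STE is used for $\partial Q/\partial \tilde{w}$ so that the chain rule can be written by hand.

```python
import numpy as np

def sign_q(v):
    """Illustrative binary quantizer, mapping 0 to +1."""
    return np.where(v >= 0, 1.0, -1.0)

def meta_forward_and_grads(a, b, w, g_prev, x, y, lr):
    """Forward with Q(w - lr * M(g_prev, w)) where M(g, w) = a * g + b * w.

    Returns the loss and its gradients w.r.t. the meta parameters (a, b).
    """
    meta_grad = a * g_prev + b * w                      # M(g_prev, w)
    w_tilde = w - lr * meta_grad                        # weight after a virtual SGD step
    q = sign_q(w_tilde)                                 # quantized weight used in inference
    loss = 0.5 * (q * x - y) ** 2
    dl_dq = (q * x - y) * x
    dq_dwt = np.where(np.abs(w_tilde) <= 1, 1.0, 0.0)  # STE
    dwt_dmeta = -lr                                     # d w_tilde / d meta_grad
    dl_da = dl_dq * dq_dwt * dwt_dmeta * g_prev
    dl_db = dl_dq * dq_dwt * dwt_dmeta * w
    return loss, dl_da, dl_db

# g_prev plays the role of the previous iteration's gradient
loss, dl_da, dl_db = meta_forward_and_grads(a=1.0, b=0.0, w=0.5, g_prev=0.2,
                                            x=2.0, y=-1.0, lr=0.1)
print(loss, dl_da, dl_db)  # loss = 4.5, dl/da = -0.12, dl/db = -0.3 (up to rounding)
```

If the forward pass used `sign_q(w)` directly, the loss would not involve `a` or `b`, and the meta quantizer would receive no gradient.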

In [None]:
"""
This block runs the training loop
"""
meta_hidden_state_dict = dict() # Dictionary to store hidden states for all layers for memory-based meta network
meta_grad_dict = dict() # Dictionary to store meta net output: gradient for the original network's weight / bias

for epoch in range(MAX_EPOCH):

    if recorder.stop: break

    print('\nEpoch: %d, lr: %e' % (epoch, optimizee.param_groups[0]['lr']))

    net.train()
    end = time.time()

    recorder.reset_performance()

    for batch_idx, (inputs, targets) in enumerate(train_loader):

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        meta_optimizer.zero_grad()

        # In the very first iteration of training, no meta gradient has been generated yet,
        # so the first forward pass is conducted without one. After that, the meta gradient
        # used in the current iteration is generated from the previous iteration's
        # gradients and weights.
        if not (batch_idx == 0 and epoch == 0):
            meta_grad_dict, meta_hidden_state_dict = \
                meta_gradient_generation(
                        meta_net, net, meta_method, meta_hidden_state_dict
                )
        # Forward pass using the meta gradient
        outputs = net(inputs, quantized_type=quantized_type,
                      meta_grad_dict=meta_grad_dict,
                      lr=optimizee.param_groups[0]['lr'])

        optimizee.zero_grad()

        # The backward pass generates gradients for both the meta quantizer and the base model.
        # Gradients of non-meta weights (biases, BN layers) are obtained here.
        losses = nn.CrossEntropyLoss()(outputs, targets)
        losses.backward()

        meta_optimizer.step()

        # Assign the meta gradients as the actual gradients used in update_parameters
        if len(meta_grad_dict) != 0:
            for layer_info in net.layer_name_list:
                layer_name = layer_info[0]
                layer_idx = layer_info[1]
                layer = get_layer(net, layer_idx)
                layer.weight.grad.data = (layer.calibration * layer.pre_quantized_grads)
                # layer.weight.grad.data.copy_(layer.calibration * meta_grad_dict[layer_name][1].data)

        # Get refined gradients for the next computation
        optimizee.get_refine_gradient()

        # These gradients are kept for the next iteration's inference
        if len(meta_grad_dict) != 0:
            update_parameters(net, lr=optimizee.param_groups[0]['lr'])

        recorder.update(loss=losses.item(), acc=accuracy(outputs.data, targets.data, (1,5)),
                        batch_size=outputs.shape[0], cur_lr=optimizee.param_groups[0]['lr'], end=end)

        recorder.print_training_result(batch_idx, len(train_loader))
        end = time.time()

    test_acc = test(net, quantized_type=quantized_type, test_loader=test_loader,
                    dataset_name=dataset_name, n_batches_used=None)
    recorder.update(loss=None, acc=test_acc, batch_size=0, end=None, is_train=False)

    # Adjust learning rate
    recorder.adjust_lr(optimizer=optimizee, adjust_type=lr_adjust, epoch=epoch)

best_test_acc = recorder.get_best_test_acc()
if isinstance(best_test_acc, tuple):
    print('Best test top-1 acc: %.3f, top-5 acc: %.3f' % (best_test_acc[0], best_test_acc[1]))
else:
    print('Best test acc: %.3f' % best_test_acc)
recorder.close()