## Libraries & Setup

In [5]:
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
#import models
import torch.distributed as dist
#from data import DataRegime
#from utils.log import setup_logging, ResultsLog
#from utils.optim import OptimRegime
#from utils.cross_entropy import CrossEntropyLoss
from torch.nn import CrossEntropyLoss 
from torch.optim import Adam
#from utils.misc import torch_dtypes
#from utils.param_filter import FilterModules, is_bn
#from utils.convert_pytcv_model import convert_pytcv_model
from datetime import datetime
from ast import literal_eval
#from trainer import Trainer
from utils.adaquant import optimize_layer_adaquant
import numpy as np
import pandas as pd
import ast
from functools import partial
import random
import os
from utils.quantize import QConv2d, QLinear

In [2]:
# Lookup table from a dtype's name to the torch dtype object of that name.
# Both the explicit names ('float32', 'int64', ...) and torch's short aliases
# ('float', 'long', ...) are present as keys.
torch_dtypes = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in (
        'float', 'float32',
        'float64', 'double',
        'float16', 'half',
        'uint8',
        'int8',
        'int16', 'short',
        'int32', 'int',
        'int64', 'long',
    )
}

In [3]:
# Run-level bookkeeping values. None of them is updated in the visible cells;
# presumably placeholders until real metrics are available -- TODO confirm.
acc, loss = -1, -1
best_prec1 = 0
dtype = torch.float32

### SET SEED
# One seed for Python's `random`, NumPy and torch (CUDA is seeded below,
# only when a GPU is present).
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

### LOGGING
#setup_logging(os.path.join(save_path, 'log.txt'),
#                  resume=args.resume is not '',
#                  dummy=args.distributed and args.local_rank > 0)
#results_path = os.path.join(save_path, 'results')
#results = ResultsLog(
#        results_path, title='Training Results - %s' % args.save)
#logging.info("saving to %s", save_path)
#logging.debug("run arguments: %s", args)
#logging.info("creating model %s", args.model)

### DEVICES
# `device_ids` stays None on CPU-only machines; with CUDA it lists every
# visible device index and the first one becomes the current device.
device_ids = None
if torch.cuda.is_available():
    device_ids = list(range(torch.cuda.device_count()))
    torch.cuda.manual_seed_all(seed)
    torch.cuda.set_device(device_ids[0])
    # Let cuDNN auto-tune its algorithms for the observed input shapes.
    cudnn.benchmark = True

## Model, Optimizer, Trainer

In [None]:
### CREATE MODEL
# NOTE(review): this cell is an unfinished template. Every assignment marked
# TODO has no right-hand side yet, so the cell will not parse until the
# placeholders are filled in.
# `model` is used by later cells through model.named_modules(),
# model.modules() and model.state_dict().
model = # TODO: build or load the model
dataset_type = 'imagenet' 
model_config = {'dataset': dataset_type}

# Loss function (criterion).
criterion = CrossEntropyLoss()

# Optimizer configuration: take the model's own `regime` attribute when it has
# one, otherwise use the single-entry default list below. The default list is
# an argument to getattr, so it is built even when `model.regime` exists.
# NOTE(review): torch.optim.Adam requires a `params` argument, so `Adam()`
# raises TypeError as written, and Adam has no `momentum` option. Confirm what
# the consumer of this regime expects for 'optimizer' and 'momentum'.
optimizer = getattr(model, 'regime', [{'epoch': 0,
                                          'optimizer': Adam() ,
                                          'lr': ,# TODO: learning rate (was args.lr)
                                          'momentum': ,# TODO: momentum (was args.momentum)
                                          'weight_decay':  # TODO: weight decay (was args.weight_decay)
                                          }])

# Training data loading.
# Later cells call train_data.get_loader() and pass the result to
# trainer.validate(), so whatever is assigned here must provide get_loader().
train_data = # TODO: training data object (presumably a DataRegime -- confirm)

# Only referenced by the commented-out Trainer construction below.
prunner = None 

# TRAINER
# NOTE(review): the construction below is commented out, but later cells call
# trainer.validate(...), so `trainer` must be defined before they can run.
#trainer = Trainer(model,prunner, criterion, optimizer,
#                      device_ids=args.device_ids, device=args.device, dtype=dtype,
#                      distributed=args.distributed, local_rank=args.local_rank, mixup=args.mixup, loss_scale=args.loss_scale,
#                      grad_clip=args.grad_clip, print_freq=args.print_freq, adapt_grad_norm=args.adapt_grad_norm,epoch=args.start_epoch,update_only_th=args.update_only_th,optimize_rounding=args.optimize_rounding)

    
# Evaluation data loading.
eval_batch_size = # TODO: evaluation batch size
# Re-assigns the same value already set near the top of this cell.
dataset_type = 'imagenet' 
# The final cell calls val_data.get_loader() for the closing validation pass.
val_data =  # TODO: validation data object (presumably a DataRegime -- confirm)

## Cache, Hook

In [None]:
# Store filled by `hook`: maps a module object to a list of
# (input, output) pairs, one pair per forward call, moved to the CPU.
cached_input_output = {}

# Parameter/buffer name suffixes of a quantized layer.
# NOTE(review): not referenced by any other visible cell -- confirm where this
# list is consumed.
# One entry per line with a trailing comma: the previous layout was missing a
# comma, so two adjacent literals were silently concatenated into a single
# bogus key ('.quantize_input1.running_range.quantize_input2.running_zero_point').
quant_keys = [
    '.weight',
    '.bias',
    '.equ_scale',
    '.quantize_input.running_zero_point',
    '.quantize_input.running_range',
    '.quantize_weight.running_zero_point',
    '.quantize_weight.running_range',
    '.quantize_input1.running_zero_point',
    '.quantize_input1.running_range',
    '.quantize_input2.running_zero_point',
    '.quantize_input2.running_range',
]
# Inputs captured by `Qhook`, keyed by module object. It lives at module level
# so that the hook and the code reading the captured inputs share one store.
# The per-layer loop further down rebinds this name to a fresh dict before
# each capture pass; `Qhook` looks the name up at call time, so it always
# writes into the current dict.
cached_qinput = {}


def Qhook(name, module, input, output):
    """Forward hook that records a module's first positional input.

    The tensor ``input[0]`` is detached, moved to the CPU and appended to
    ``cached_qinput[module]``, so repeated forward calls accumulate one entry
    each.

    Args:
        name: Layer name bound via ``functools.partial`` at registration time;
            not used by the hook itself.
        module: The module whose forward pass just ran; used as the cache key.
        input: Tuple of positional inputs passed to the module's forward.
        output: The module's output; ignored.
    """
    # The previous version created a *local* `cached_qinput` on every call, so
    # nothing ever reached the module-level store and readers hit a KeyError.
    # Meanwhile store data in the RAM.
    cached_qinput.setdefault(module, []).append(input[0].detach().cpu())

def hook(name, module, input, output):
    """Forward hook that records one (input, output) pair for a module.

    The first positional input and the output are each detached and moved to
    the CPU, then appended as a tuple to ``cached_input_output[module]``; the
    per-module list is created on first use.

    Args:
        name: Layer name bound via ``functools.partial`` at registration time;
            not used by the hook itself.
        module: The module whose forward pass just ran; used as the cache key.
        input: Tuple of positional inputs passed to the module's forward.
        output: The module's output (must support ``detach()``).
    """
    # Fetch (or create) this module's list first, then build the pair.
    records = cached_input_output.setdefault(module, [])
    # Meanwhile store data in the RAM.
    captured_in = input[0].detach().cpu()
    captured_out = output.detach().cpu()
    records.append((captured_in, captured_out))

In [None]:

# Record an (input, output) pair for every QConv2d/QLinear layer by running
# one validation pass over the training loader with `hook` attached.
# NOTE: `hook` appends to `cached_input_output`, so re-running this cell adds
# further entries to the existing per-module lists.
handlers = []
count = 0
try:
    for name, m in model.named_modules():
        if isinstance(m, (QConv2d, QLinear)):
            # Presumably switches quantization off for the recording pass --
            # TODO confirm against utils.quantize.
            m.quantize = False
            # At most 1000 layers get a caching hook.
            if count < 1000:
                handlers.append(m.register_forward_hook(partial(hook, name)))
                count += 1

    # Store input/output for all quantizable layers
    trainer.validate(train_data.get_loader())
    print("Input/outputs cached")
finally:
    # Always undo the temporary state, even when the pass above fails.
    # Otherwise the hooks stay attached (and would be registered a second
    # time when the cell is re-run) and the layers stay at quantize = False.
    for handler in handlers:
        handler.remove()

    for m in model.modules():
        if isinstance(m, (QConv2d, QLinear)):
            m.quantize = True


## Loop Through Layers

In [None]:
# Results table with one pre-allocated row per cached layer; the per-layer
# loop fills each row by position.
mse_columns = ['name', 'bit', 'shape', 'mse_before', 'mse_after']
mse_df = pd.DataFrame(
    columns=mse_columns,
    index=np.arange(len(cached_input_output)),
)
# Not referenced by any other visible cell.
print_freq = 100
# Prefix of the output file names built later (evaluate + '.mse.csv', ...).
evaluate = "evaluate"

In [None]:
# Optimize every cached layer in turn and log the MSE before/after.
for i, layer in enumerate(cached_input_output):
    if i > 0:  # and seq_adaquant = True
        # For every layer after the first, run another validation pass with
        # `Qhook` attached to this layer and use the newly captured input in
        # place of the one cached earlier (the cached output is kept).
        # Presumably this makes the layer see inputs produced by the layers
        # already optimized -- TODO confirm.
        cached_qinput = {}
        q_handlers = []
        try:
            for name, m in model.named_modules():
                if layer.name == name:
                    q_handlers.append(m.register_forward_hook(partial(Qhook, name)))
            if not q_handlers:
                # Fail before the expensive pass: without a hook nothing can
                # be captured for this layer. The old code reached
                # `handler.remove()` with `handler` unbound (or left over
                # from the previous layer) in this situation.
                raise RuntimeError(
                    "no module named {!r} found in model.named_modules()".format(layer.name))
            # Store input/output for all quantizable layers
            trainer.validate(train_data.get_loader())
        finally:
            # Detach the hook(s) even if the validation pass raises.
            for handler in q_handlers:
                handler.remove()
        print("cashed quant Input%s"%layer.name)
        # NOTE(review): only the first cached pair gets the re-captured
        # input; any further cached pairs keep their original inputs.
        cached_input_output[layer][0] = (cached_qinput[layer][0], cached_input_output[layer][0][1])
    print("\nOptimize {}:{} for {} bit of shape {}".format(i, layer.name, layer.num_bits, layer.weight.shape))

    mse_before, mse_after = optimize_layer_adaquant(layer, cached_input_output[layer])

    print("\nMSE before optimization: {}".format(mse_before))
    print("MSE after optimization:  {}".format(mse_after))
    mse_df.loc[i, 'name'] = layer.name
    mse_df.loc[i, 'bit'] = layer.num_bits
    mse_df.loc[i, 'shape'] = str(layer.weight.shape)
    mse_df.loc[i, 'mse_before'] = mse_before
    mse_df.loc[i, 'mse_after'] = mse_after


# Both output files share the `evaluate` prefix.
mse_csv = ''.join((evaluate, '.mse.csv'))
filename = ''.join((evaluate, '.adaquant'))

# Write the per-layer MSE table, then the model's state dict.
mse_df.to_csv(mse_csv)
torch.save(model.state_dict(), filename)

# Drop the references to the training data and the cached activations before
# the final validation pass on the evaluation data.
train_data = cached_input_output = None
val_loader = val_data.get_loader()
val_results = trainer.validate(val_loader)
#logging.info(val_results)