In [None]:
import argparse
import os
import time
import logging
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
import torch.distributed as dist
from data import DataRegime
from utils.log import setup_logging, ResultsLog, save_checkpoint
from utils.optim import OptimRegime
from utils.cross_entropy import CrossEntropyLoss
from utils.misc import torch_dtypes
from utils.param_filter import FilterModules, is_bn
from utils.convert_pytcv_model import convert_pytcv_model
from datetime import datetime
from ast import literal_eval
from trainer import Trainer
from utils.adaquant import *
import torchvision
import scipy.optimize as opt
import torch.nn.functional as F
import warnings
import numpy as np
from models.modules.quantize import methods
from tqdm import tqdm
import pandas as pd
import math
import shutil
from models.modules.quantize import QParams
import ast
import ntpath
from functools import partial

In [None]:
global best_prec1, dtype
    acc = -1
    loss = -1
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save is '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    # Distributed mode is enabled when a rank was given (local_rank >= 0) or
    # when more than one process is requested.
    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        # Overwrite the requested values with what the process group reports.
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            # Any other backend: a single device whose index is this rank.
            args.device_ids = [args.local_rank]

    # In a distributed run only rank 0 creates the output directory and writes
    # a real log; every other rank gets a dummy logger.
    is_secondary_rank = args.distributed and args.local_rank > 0

    if not os.path.exists(save_path) and not is_secondary_rank:
        os.makedirs(save_path)

    # `!=` instead of `is not`: identity comparison against a string literal is
    # unreliable and raises a SyntaxWarning on Python >= 3.8.
    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=is_secondary_rank)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(
        results_path, title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    # Use CUDA only when it is both requested in args.device and available.
    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        # The first listed device becomes the current CUDA device.
        torch.cuda.set_device(args.device_ids[0])
        # Enable cuDNN benchmark mode.
        cudnn.benchmark = True
    else:
        # No usable GPU: drop the device list.
        args.device_ids = None

    # Look up the model factory by name; it is instantiated further below.
    model = models.__dict__[args.model]
    # The 'imagenet_calib' dataset uses the plain 'imagenet' model settings.
    dataset_type = 'imagenet' if args.dataset == 'imagenet_calib' else args.dataset
    model_config = {'dataset': dataset_type}

    # `!=` instead of `is not`: identity comparison against a string literal is
    # unreliable and raises a SyntaxWarning on Python >= 3.8.
    if args.model_config != '':
        if isinstance(args.model_config, dict):
            # A dict config only fills in keys that are not set yet, so the
            # 'dataset' entry chosen above wins.
            for k, v in args.model_config.items():
                model_config.setdefault(k, v)
        else:
            # A string config is parsed as a Python literal and overrides
            # existing entries, including 'dataset'.
            args_dict = literal_eval(args.model_config)
            for k, v in args_dict.items():
                model_config[k] = v
    if (args.absorb_bn or args.load_from_vision or args.pretrained) and not args.batch_norn_tuning:
        if args.load_from_vision:
            import torchvision
            exec_lfv_str = 'torchvision.models.' + args.load_from_vision + '(pretrained=True)'
            model = eval(exec_lfv_str)
            if 'pytcv' in args.model:
                from pytorchcv.model_provider import get_model as ptcv_get_model
                exec_lfv_str ='ptcv_get_model("'+ args.load_from_vision +'", pretrained=True)'
                model_pytcv = eval(exec_lfv_str)
                model = convert_pytcv_model(model,model_pytcv)
        else:
            if not os.path.isfile(args.absorb_bn):
                parser.error('invalid checkpoint: {}'.format(args.evaluate))
            model = model(**model_config)
            checkpoint = torch.load(args.absorb_bn,map_location=lambda storage, loc: storage)
            checkpoint = checkpoint['state_dict'] if 'state_dict' in checkpoint.keys() else checkpoint
            model.load_state_dict(checkpoint,strict=False)
        if 'batch_norm' in model_config and not model_config['batch_norm']:
            logging.info('Creating absorb_bn state dict')
            search_absorbe_bn(model)
            filename_ab = args.absorb_bn+'.absorb_bn' if args.absorb_bn else save_path+'/'+args.model+'.absorb_bn'
            torch.save(model.state_dict(),filename_ab)
        else:    
            filename_bn = save_path+'/'+args.model+'.with_bn'
            torch.save(model.state_dict(),filename_bn)
        if (args.load_from_vision or args.absorb_bn) and not args.evaluate_init_configuration: return

    # Instantiate the network. Models whose name contains 'inception' are
    # additionally built with init_weights=False.
    if 'inception' not in args.model:
        model = model(**model_config)
    else:
        model = model(init_weights=False, **model_config)
    logging.info("created model with configuration: %s", model_config)

    # Total number of elements over all parameter tensors.
    num_parameters = sum(param.nelement() for param in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Optionally load the weights to evaluate from the checkpoint at
    # args.evaluate.
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # A config without a 'batch_norm' entry is treated as "has batch norm",
        # matching the `'batch_norm' in model_config and not ...` check used
        # when the absorb_bn state dict is created; a plain index lookup would
        # raise KeyError here.
        if not model_config.get('batch_norm', True):
            search_absorbe_fake_bn(model)
        # A wrapped checkpoint is loaded strictly; a bare state dict is loaded
        # with strict=False.
        if 'state_dict' in checkpoint.keys():
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint, strict=False)
        logging.info("loaded checkpoint '%s'", args.evaluate)
          

    # Optionally resume from args.resume, which may be a checkpoint file or a
    # directory containing 'results.csv' and 'model_best.pth.tar'.
    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            # Directory: reload the logged results and pick the best checkpoint.
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            if args.start_epoch < 0:  # not explicitly set
                # Take the epoch from the checkpoint (minus one), or 0 if absent.
                args.start_epoch = checkpoint['epoch'] - 1 if 'epoch' in checkpoint.keys() else 0
            # -1 when the checkpoint carries no best_prec1 value.
            best_prec1 = checkpoint['best_prec1'] if 'best_prec1' in checkpoint.keys() else -1
            # Accept both a wrapped checkpoint and a bare state dict.
            sd = checkpoint['state_dict'] if 'state_dict' in checkpoint.keys() else checkpoint
            # strict=False: missing or unexpected keys do not raise.
            model.load_state_dict(sd,strict=False)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, args.start_epoch)
        else:
            # A missing checkpoint is only logged; execution continues.
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    # A model may supply its own `criterion` factory; otherwise use the
    # project's CrossEntropyLoss.
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    if args.kld_loss:
       # Replaces the criterion chosen above (loss_params are not applied).
       # NOTE(review): PyTorch documents reduction='batchmean' as the reduction
       # matching the mathematical KL divergence - confirm 'mean' is intended.
       criterion = nn.KLDivLoss(reduction='mean')
    # Move both the loss and the model to the target device and dtype.
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    # A model may supply its own `regime`; otherwise build a single-entry
    # regime from the command-line optimizer settings.
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])
    # Fine-tuning / pruning replaces the regime with a fixed SGD schedule.
    if args.fine_tune or args.prune:
        # Without a resume checkpoint, start from epoch 0.
        if not args.resume: args.start_epoch=0
        if args.update_only_th:
            #optim_regime = [
            #    {'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-4}]
            # Schedule used when update_only_th is set: lr 1e-1, then 1e-2 at
            # epoch 10 and 1e-3 at epoch 15.
            optim_regime = [
                {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1},
                {'epoch': 10, 'lr': 1e-2},
                {'epoch': 15, 'lr': 1e-3}]
        else:
            # Otherwise: lr 1e-4, then 1e-5 at epoch 2 and 1e-6 at epoch 10,
            # all with momentum 0.9.
            optim_regime = [
                {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-4, 'momentum': 0.9},
                {'epoch': 2, 'lr': 1e-5, 'momentum': 0.9},
                {'epoch': 10, 'lr': 1e-6, 'momentum': 0.9}]
    # Wrap a plain regime list in OptimRegime; an OptimRegime is used as is.
    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    # Training Data loading code
    # The model's own `data_regime` (if any) is passed first; the dict below
    # supplies the defaults. Augmentation is off and shuffling is disabled when
    # seq_adaquant is set.

    train_data = DataRegime(getattr(model, 'data_regime', None),
                            defaults={'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'train', 'augment': False,
                                      'input_size': args.input_size,  'batch_size': args.batch_size, 'shuffle': not args.seq_adaquant,
                                      'num_workers': args.workers, 'pin_memory': True, 'drop_last': True,
                                      'distributed': args.distributed, 'duplicates': args.duplicates, 'autoaugment': args.autoaugment,
                                      'cutout': {'holes': 1, 'length': 16} if args.cutout else None})
    # When neither an explicit layer list nor a precision dict was given,
    # derive the layer names from the model's state-dict keys.
    if args.names_sp_layers is None and args.layers_precision_dict is None:
        # Keep conv / downsample.0 / fc weight keys (excluding 'running' keys)
        # and drop their last 7 characters, the length of '.weight'.
        args.names_sp_layers =  [key[:-7] for key in model.state_dict().keys() if 'weight' in key and 'running' not in key and ('conv' in key or 'downsample.0' in key or 'fc' in key)]
        # Optionally exclude the layers named 'conv1', 'fc' and
        # 'Conv2d_1a_3x3.conv'.
        if args.keep_first_last: args.names_sp_layers=[name for name in args.names_sp_layers if name!='conv1' and name!='fc' and name != 'Conv2d_1a_3x3.conv']
        # Optionally exclude every downsample layer.
        args.names_sp_layers = [k for k in args.names_sp_layers if 'downsample' not in k] if args.ignore_downsample else args.names_sp_layers
        if args.num_sp_layers == 0 and not args.keep_first_last:
            args.names_sp_layers = []

    # No pruner is used here; None is passed as the Trainer's second argument.
    prunner = None
    # Bundle model, criterion and optimizer with the run settings. The trainer
    # starts at args.start_epoch.
    trainer = Trainer(model,prunner, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device, dtype=dtype,
                      distributed=args.distributed, local_rank=args.local_rank, mixup=args.mixup, loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip, print_freq=args.print_freq, adapt_grad_norm=args.adapt_grad_norm,epoch=args.start_epoch,update_only_th=args.update_only_th,optimize_rounding=args.optimize_rounding)

    
    # Evaluation data pipeline.
    # A non-positive eval batch size falls back to the training batch size.
    if not args.eval_batch_size > 0:
        args.eval_batch_size = args.batch_size
    # Validation for 'imagenet_calib' runs on the plain 'imagenet' dataset.
    dataset_type = args.dataset
    if dataset_type == 'imagenet_calib':
        dataset_type = 'imagenet'
    eval_defaults = {
        'datasets_path': args.datasets_dir,
        'name': dataset_type,
        'split': 'val',
        'augment': False,
        'input_size': args.input_size,
        'batch_size': args.eval_batch_size,
        'shuffle': True,
        'num_workers': args.workers,
        'pin_memory': True,
        'drop_last': False,
    }
    # The model's own `data_eval_regime` (if any) is passed ahead of the defaults.
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults=eval_defaults)

    # With a loaded checkpoint (evaluate or resume), swap layers of the model
    # through the project's layer_sensativity helpers.
    if args.evaluate or args.resume:
        from utils.layer_sensativity import search_replace_layer , extract_save_quant_state_dict, search_replace_layer_from_dict
        if args.layers_precision_dict is not None:
            # The precision dict arrives as a string holding a Python literal.
            model = search_replace_layer_from_dict(model, ast.literal_eval(args.layers_precision_dict))
        else:
            # Otherwise apply the same weight/activation bit widths to every
            # layer named in args.names_sp_layers.
            model = search_replace_layer(model, args.names_sp_layers, num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)

    # Cache of captured activations, filled by the forward hooks below:
    # module -> list of (input, output) pairs.
    cached_input_output = {}
    # Parameter / buffer name suffixes of a quantized layer.
    # One entry per line with trailing commas: the original list lacked a comma
    # after '.quantize_input1.running_range', so Python silently concatenated it
    # with the next literal into a single bogus suffix.
    quant_keys = [
        '.weight',
        '.bias',
        '.equ_scale',
        '.quantize_input.running_zero_point',
        '.quantize_input.running_range',
        '.quantize_weight.running_zero_point',
        '.quantize_weight.running_range',
        '.quantize_input1.running_zero_point',
        '.quantize_input1.running_range',
        '.quantize_input2.running_zero_point',
        '.quantize_input2.running_range',
    ]
    if args.adaquant:
        def Qhook(name, module, input, output):
            """Forward hook that records the first input of `module`.

            A detached CPU copy of input[0] is appended to
            cached_qinput[module]; `name` and `output` are not used.
            """
            # Keep the captured data in host RAM rather than on the device.
            bucket = cached_qinput.setdefault(module, [])
            bucket.append(input[0].detach().cpu())

        def hook(name, module, input, output):
            """Forward hook that records an (input, output) pair of `module`.

            Detached CPU copies of input[0] and of `output` are appended as a
            tuple to cached_input_output[module]; `name` is not used.
            """
            # Keep the captured data in host RAM rather than on the device.
            bucket = cached_input_output.setdefault(module, [])
            bucket.append((input[0].detach().cpu(), output.detach().cpu()))

        from models.modules.quantize import QConv2d, QLinear
        # Handles of the registered hooks, kept so they can be removed later.
        handlers = []
        count = 0
        # Switch quantization off on every QConv2d / QLinear layer and attach
        # the caching hook to at most the first 1000 of them.
        for name, m in model.named_modules():
            if isinstance(m, QConv2d) or isinstance(m, QLinear):
                m.quantize = False
                if count < 1000:
                    # partial() binds the layer name as the hook's first argument.
                    handlers.append(m.register_forward_hook(partial(hook,name)))
                    count += 1

        # Store input/output for all quantizable layers
        # by running a validation pass over the training loader.
        trainer.validate(train_data.get_loader())
        print("Input/outputs cached")

        # The hooks are only needed for the caching pass above.
        for handler in handlers:
            handler.remove()

        # Switch quantization back on for the optimization that follows.
        for m in model.modules():
            if isinstance(m, QConv2d) or isinstance(m, QLinear):
                m.quantize = True

        # One row of statistics per cached layer. The snr_* and kurt_* columns
        # are not declared here; they are added by the .loc assignments below.
        mse_df = pd.DataFrame(index=np.arange(len(cached_input_output)), columns=['name', 'bit', 'shape', 'mse_before', 'mse_after'])
        print_freq = 100
        # Iterating the cache dict yields its keys, i.e. the layer modules.
        for i, layer in enumerate(cached_input_output):
            if i>0 and args.seq_adaquant:
                # Sequential mode, every layer after the first: capture this
                # layer's input again with the model in its current state.
                count = 0
                cached_qinput = {}
                for name, m in model.named_modules():
                    if layer.name==name:
                        if count < 1000:
                            handler= m.register_forward_hook(partial(Qhook,name))
                            count += 1
                # Re-run the training loader so Qhook fills cached_qinput
                trainer.validate(train_data.get_loader())
                print("cashed quant Input%s"%layer.name)
                # Only the first cached entry is updated: its input is replaced
                # by the newly captured one, its original output is kept.
                cached_input_output[layer][0] = (cached_qinput[layer][0],cached_input_output[layer][0][1])
                # NOTE(review): `handler` is only assigned when a module name
                # matched layer.name above - confirm a match is guaranteed.
                handler.remove()
            print("\nOptimize {}:{} for {} bit of shape {}".format(i, layer.name, layer.num_bits, layer.weight.shape))
            # optimize_layer is not defined in the visible code (presumably
            # from utils.adaquant); it returns six statistics.
            mse_before, mse_after, snr_before, snr_after, kurt_in, kurt_w = \
                optimize_layer(layer, cached_input_output[layer], args.optimize_weights, batch_size=args.batch_size, model_name=args.model)
            print("\nMSE before optimization: {}".format(mse_before))
            print("MSE after optimization:  {}".format(mse_after))
            # Record this layer's statistics in row i.
            mse_df.loc[i, 'name'] = layer.name
            mse_df.loc[i, 'bit'] = layer.num_bits
            mse_df.loc[i, 'shape'] = str(layer.weight.shape)
            mse_df.loc[i, 'mse_before'] = mse_before
            mse_df.loc[i, 'mse_after'] = mse_after
            mse_df.loc[i, 'snr_before'] = snr_before
            mse_df.loc[i, 'snr_after'] = snr_after
            mse_df.loc[i, 'kurt_in'] = kurt_in
            mse_df.loc[i, 'kurt_w'] = kurt_w

        # Write the per-layer statistics next to the evaluated checkpoint.
        # NOTE(review): this block can also be reached with only args.resume
        # set; the output paths are then built from an empty or unset
        # args.evaluate - confirm that is intended.
        mse_csv = args.evaluate + '.mse.csv'
        mse_df.to_csv(mse_csv)

        # Save the model's state dict after the per-layer optimization.
        filename = args.evaluate + '.adaquant'
        torch.save(model.state_dict(), filename)

        # Drop the references to the training data and the cached activations
        # before the final validation pass.
        train_data = None
        cached_input_output = None
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)