<a href="https://colab.research.google.com/github/monishramadoss/ofa_quant/blob/main/ofa_quant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch==1.8.1 torchvision==0.9.1
!pip install ofa==0.0.4-2012082155
! git clone https://github.com/seshuad/IMagenet
! ls 'IMagenet/tiny-imagenet-200/'

fatal: destination path 'IMagenet' already exists and is not an empty directory.
test  train  val  wnids.txt  words.txt


In [6]:
import time
import skimage.io as nd
import numpy as np
import torch
import skimage.color


# Root of the dataset inside the repository cloned in the setup cell.
# The trailing '/' is required: the loaders below build file names with plain
# string concatenation (path + 'wnids.txt', path + 'train/...').
path = 'IMagenet/tiny-imagenet-200/'

def get_id_dictionary(root=None):
    """Map every id listed in ``wnids.txt`` to its zero-based line number.

    Args:
        root: Directory that contains ``wnids.txt``. When omitted, the
            notebook-level ``path`` is used, which matches the original
            behaviour.

    Returns:
        dict: ``{id_string: line_index}`` in file order. If an id appears more
        than once, the index of its last occurrence wins.
    """
    import os  # local import so the cell does not depend on another cell's imports

    if root is None:
        root = path
    # The context manager closes the handle; the previous bare open() leaked it.
    with open(os.path.join(root, 'wnids.txt'), 'r') as wnids_file:
        return {line.replace('\n', ''): i for i, line in enumerate(wnids_file)}
  
def get_class_to_id_dict():
    """Build a lookup from class index to ``(id, description)``.

    The ids and their indices come from ``get_id_dictionary()``; the
    descriptions are the second tab-separated field of each line in
    ``words.txt``.

    Returns:
        dict: ``{class_index: (id_string, description)}``.

    Raises:
        KeyError: if an id from ``wnids.txt`` has no entry in ``words.txt``.
        ValueError: if a line of ``words.txt`` has fewer than two
            tab-separated fields.
    """
    id_dict = get_id_dictionary()
    all_classes = {}
    # Close the file deterministically instead of leaking the handle.
    with open(path + 'words.txt', 'r') as words_file:
        for line in words_file:
            # Drop the line terminator first; otherwise a two-field line leaves
            # a trailing '\n' glued to the description.
            n_id, word = line.rstrip('\n').split('\t')[:2]
            all_classes[n_id] = word
    return {index: (n_id, all_classes[n_id]) for n_id, index in id_dict.items()}

def get_data(id_dict, num_classes=200, images_per_class=500):
    """Load the training and validation images together with one-hot labels.

    Training images are read from
    ``train/<id>/images/<id>_<k>.JPEG`` for ``k`` in ``range(images_per_class)``;
    validation images and their class ids come from
    ``val/val_annotations.txt``. Two-dimensional (grayscale) images are
    expanded to RGB with ``skimage.color.gray2rgb``.

    Args:
        id_dict: Mapping ``{id_string: class_index}``.
        num_classes: Width of every one-hot label row (previously the
            hard-coded 200).
        images_per_class: Number of training images read per class
            (previously the hard-coded 500).

    Returns:
        tuple: ``(train_data, train_labels, test_data, test_labels)`` as NumPy
        arrays; the "test" pair holds the validation split.
    """
    print('starting loading data')
    train_data, test_data = [], []
    train_labels, test_labels = [], []
    t = time.time()

    def _load_rgb(image_path):
        # Read one image and promote single-channel images to three channels.
        image = nd.imread(image_path, as_gray=False)
        return skimage.color.gray2rgb(image) if image.ndim == 2 else image

    def _one_hot(class_index):
        # One label row: all zeros except a 1 at class_index.
        row = [0] * num_classes
        row[class_index] = 1
        return row

    for key, value in id_dict.items():
        train_data += [
            _load_rgb(path + 'train/{}/images/{}_{}.JPEG'.format(key, key, str(i)))
            for i in range(images_per_class)
        ]
        train_labels += [_one_hot(value) for _ in range(images_per_class)]

    # Close the annotations file deterministically instead of leaking the handle.
    with open(path + 'val/val_annotations.txt') as annotations:
        for line in annotations:
            img_name, class_id = line.split('\t')[:2]
            test_data.append(_load_rgb(path + 'val/images/{}'.format(img_name)))
            test_labels.append(_one_hot(id_dict[class_id]))

    print('finished loading data, in {} seconds'.format(time.time() - t))
    return np.array(train_data), np.array(train_labels), np.array(test_data), np.array(test_labels)

def evaluate(model, data, target, label_map=None):
    """Measure the mean per-sample inference time and the accuracy of ``model``.

    Each sample is converted to a tensor and passed to ``model`` on its own;
    the index of the largest entry of ``result[0]`` is translated through
    ``label_map`` and compared with the sample's target.

    Args:
        model: A ``torch.nn.Module``; it is switched to eval mode.
        data: Sized iterable of samples.
        target: Iterable of expected labels, paired with ``data`` by position.
            NOTE(review): the comparison below only works when each target
            converts to a single-element tensor; the one-hot rows produced by
            ``get_data`` would make the ``if`` ambiguous - confirm the intended
            label format.
        label_map: Mapping from predicted class index to label. When omitted,
            the notebook-level ``idx2label`` is used, as in the original code.
            NOTE(review): ``idx2label`` is not defined in the visible cells.

    Returns:
        tuple: ``(mean_seconds_per_sample, accuracy)``; ``(0.0, 0.0)`` when
        ``data`` is empty.
    """
    model.eval()  # set model in eval mode
    num_samples = len(data)
    if num_samples == 0:
        # Nothing to measure; avoids the ZeroDivisionError of the averages below.
        return 0.0, 0.0
    if label_map is None:
        label_map = idx2label

    total_time = 0
    num_correct = 0
    with torch.no_grad():
        # Distinct loop names: the original rebound the `target` parameter here.
        for sample, expected in zip(data, target):
            start = time.time()
            # as_tensor accepts lists, arrays and existing tensors without the
            # copy-construct warning torch.tensor() emits for tensors.
            image = torch.as_tensor(sample)
            expected = torch.as_tensor(expected)
            result = model(image)
            total_time += time.time() - start

            prediction = label_map[int(result[0].argmax())]
            if expected == prediction:
                num_correct += 1

    inference_time = total_time / num_samples
    accuracy = num_correct / num_samples
    return inference_time, accuracy

# Reads every image into memory; the recorded run of this cell took about 56 seconds.
train_data, train_labels, test_data, test_labels = get_data(get_id_dictionary())

starting loading data
finished loading data, in 56.29645776748657 seconds


In [7]:
import copy
import torch
from ofa.model_zoo import ofa_net
from ofa.imagenet_classification.data_providers.imagenet import ImagenetDataProvider
from ofa.imagenet_classification.run_manager import ImagenetRunConfig, RunManager

# The 'ofa_resnet50' network from the OFA model zoo, requested with pretrained
# weights; every subnet below is extracted from it.
raw_resnet = ofa_net('ofa_resnet50', pretrained=True)
# Name handed to torch.quantization.get_default_qconfig when the subnets are prepared below.
config = 'fbgemm'

In [8]:
from ofa.imagenet_classification.networks import ResNets
from ofa.utils.layers import IdentityLayer, ResidualBlock

def val2list(val, repeat_time=1):
    """Normalise ``val`` to a list-like value.

    Lists and NumPy arrays are returned unchanged (same object), tuples are
    converted to a new list, and any other value is repeated ``repeat_time``
    times in a new list.
    """
    if isinstance(val, tuple):
        return list(val)
    if isinstance(val, (list, np.ndarray)):
        return val
    return [val] * repeat_time

def set_active_subnet(ofa, d=None, e=None, w=None, **kwargs):
    """Select the active sub-network of ``ofa`` in place.

    Args:
        ofa: Supernet exposing ``BASE_DEPTH_LIST``, ``blocks``, ``input_stem``,
            ``depth_list``, ``grouped_block_index`` and ``runtime_depth``.
        d: Depth setting(s): one value for the input stem followed by one per
            stage. A scalar is broadcast via ``val2list``.
        e: Expand ratio(s), one per block; assigned as-is to
            ``block.active_expand_ratio``.
        w: Width index/indices into each layer's ``out_channel_list``: two
            values for the input stem followed by one per stage.
        **kwargs: Accepted and ignored.

    ``None`` entries leave the corresponding setting untouched.
    """
    stage_count = len(ofa.BASE_DEPTH_LIST)
    depths = val2list(d, stage_count + 1)
    ratios = val2list(e, len(ofa.blocks))
    width_ids = val2list(w, stage_count + 2)

    # Per-block expand ratio.
    for blk, ratio in zip(ofa.blocks, ratios):
        if ratio is not None:
            blk.active_expand_ratio = ratio

    # Input stem: the first width index drives stem layers 0 and 1, the second drives layer 2.
    if width_ids[0] is not None:
        stem_channels = ofa.input_stem[0].out_channel_list[width_ids[0]]
        ofa.input_stem[1].conv.active_out_channel = stem_channels
        ofa.input_stem[0].active_out_channel = stem_channels
    if width_ids[1] is not None:
        last_stem = ofa.input_stem[2]
        last_stem.active_out_channel = last_stem.out_channel_list[width_ids[1]]

    # The stem's middle layer is skipped unless the stem depth is the maximum.
    if depths[0] is not None:
        ofa.input_stem_skipping = (depths[0] != max(ofa.depth_list))

    per_stage = zip(ofa.grouped_block_index, depths[1:], width_ids[2:])
    for stage_id, (stage_blocks, stage_depth, stage_width) in enumerate(per_stage):
        if stage_depth is not None:
            # runtime_depth stores how many trailing blocks of the stage are dropped.
            ofa.runtime_depth[stage_id] = max(ofa.depth_list) - stage_depth
        if stage_width is not None:
            for idx in stage_blocks:
                blk = ofa.blocks[idx]
                blk.active_out_channel = blk.out_channel_list[stage_width]

def set_max_subnet(ofa):
    """Activate the largest configuration: max depth, max expand ratio, last width index."""
    deepest = max(ofa.depth_list)
    widest_ratio = max(ofa.expand_ratio_list)
    last_width_index = len(ofa.width_mult_list) - 1
    set_active_subnet(ofa, deepest, widest_ratio, last_width_index)

def get_active_subnet(ofa, preserve_weight=True):
    """Materialise the currently active sub-network of ``ofa`` as a ``ResNets`` model.

    Besides the model itself, three bookkeeping attributes are attached to the
    returned subnet so its layers can be related back to the supernet:

    * ``input_stem_blocks``: ``{supernet stem index: output channels}`` for
      the stem layers that were kept.
    * ``block_groups``: supernet block indices that were kept, in order.
    * ``block_input_channel``: ``{supernet block index: input channels}``.

    Args:
        ofa: Supernet whose active configuration has already been set.
        preserve_weight: Forwarded to every layer's ``get_active_subnet``.

    Returns:
        The extracted ``ResNets`` instance.
    """
    stem = ofa.input_stem

    # --- input stem -------------------------------------------------------
    stem_modules = [stem[0].get_active_subnet(3, preserve_weight)]
    stem_width = stem[0].active_out_channel
    kept_stem = {0: stem_width}

    if ofa.input_stem_skipping <= 0:
        # Middle stem layer is kept: wrap it as a residual block with an identity shortcut.
        middle_conv = stem[1].conv.get_active_subnet(stem_width, preserve_weight)
        stem_modules.append(ResidualBlock(middle_conv, IdentityLayer(stem_width, stem_width)))
        kept_stem[1] = stem_width

    stem_modules.append(stem[2].get_active_subnet(stem_width, preserve_weight))
    channels = stem[2].active_out_channel
    kept_stem[2] = channels

    # --- residual stages --------------------------------------------------
    kept_blocks = []
    kept_block_ids = []
    channels_into_block = {}
    for stage_id, stage_indices in enumerate(ofa.grouped_block_index):
        dropped = ofa.runtime_depth[stage_id]
        active_ids = stage_indices[:len(stage_indices) - dropped]
        kept_block_ids.extend(active_ids)
        for idx in active_ids:
            channels_into_block[idx] = channels
            kept_blocks.append(ofa.blocks[idx].get_active_subnet(channels, preserve_weight))
            channels = ofa.blocks[idx].active_out_channel

    # --- classifier and assembly ------------------------------------------
    head = ofa.classifier.get_active_subnet(channels, preserve_weight)

    subnet = ResNets(stem_modules, kept_blocks, head)
    subnet.set_bn_param(**ofa.get_bn_param())
    subnet.input_stem_blocks = kept_stem
    subnet.block_groups = kept_block_ids
    subnet.block_input_channel = channels_into_block
    return subnet


In [9]:
import torch.nn as nn

def nested_children(m: torch.nn.Module):
    """Return the module hierarchy of ``m`` as nested dictionaries.

    Args:
        m: Module to walk.

    Returns:
        ``m`` itself when it has no child modules; otherwise a dict mapping
        each child name to the (recursively built) structure of that child.
    """
    children = dict(m.named_children())
    if not children:
        return m
    # The former try/except TypeError repeated the identical call in its
    # handler, so it could never recover from anything; recurse directly.
    return {name: nested_children(child) for name, child in children.items()}

def squash_nested_dict(nested_dict, ret_lst=None, prefix='', mod=None):
    """Flatten a nested dict into ``{'a.b.c': leaf}`` entries.

    Args:
        nested_dict: Dict whose values are either further dicts or leaves.
        ret_lst: Dict that receives the flattened entries. A fresh dict is
            created when omitted (the old ``{}`` default was shared between
            calls).
        prefix: Dotted path accumulated so far, with a leading ``'.'``;
            used by the recursion.
        mod: Leaf stored under ``prefix`` when ``nested_dict`` is empty;
            used by the recursion.

    Returns:
        dict: ``ret_lst``, so the result is reachable even when the caller
        did not supply an accumulator. Callers that pass their own dict and
        ignore the return value behave exactly as before.
    """
    if ret_lst is None:
        ret_lst = {}
    if nested_dict == {}:
        # An empty dict marks a leaf: record it under the path without the leading dot.
        ret_lst[prefix[1:]] = mod
        return ret_lst
    for key, value in nested_dict.items():
        child_prefix = prefix + '.' + key
        if isinstance(value, dict):
            squash_nested_dict(value, ret_lst, child_prefix)
        else:
            squash_nested_dict({}, ret_lst, child_prefix, value)
    return ret_lst

def remap_weight_names(mappings, dict1):
    """Re-key the ``'input_stem'`` and ``'blocks'`` entries of a layer tree.

    Args:
        mappings: ``{'input_stem': dict, 'blocks': sequence}``. The n-th key
            of ``mappings['input_stem']`` becomes the new key of the stem
            entry currently stored under ``str(n)``; ``mappings['blocks'][n]``
            becomes the new key of the block entry currently stored under
            ``str(n)``.
        dict1: Layer tree with ``'input_stem'`` and ``'blocks'`` sub-dicts
            keyed by position strings. It is not modified.

    Returns:
        dict: Shallow copy of ``dict1`` with the two sub-dicts re-keyed.
    """
    stem_ids = mappings['input_stem']
    block_ids = mappings['blocks']
    remapped = dict1.copy()

    old_stem = remapped['input_stem']
    remapped['input_stem'] = {
        str(stem_id): old_stem[str(position)]
        for position, (stem_id, _channels) in enumerate(stem_ids.items())
    }

    old_blocks = remapped['blocks']
    remapped['blocks'] = {
        str(block_ids[int(position)]): layer
        for position, layer in old_blocks.items()
    }
    return remapped

class Quant_Model(nn.Module):
    """Run a float model while fake-quantizing the inputs of selected layers.

    Every leaf module of ``float_model`` gets a forward pre-hook. For
    ``nn.Conv2d`` and ``nn.Identity`` leaves the hook replaces the input with
    ``(clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point) * scale``,
    where ``scale`` and ``zero_point`` are read from the state dict of
    ``quat_model`` under the layer name that corresponds to the float layer.

    Args:
        quat_model: Converted (quantized) model whose state dict supplies the
            ``<layer>.scale`` and ``<layer>.zero_point`` entries.
        float_model: Float subnet carrying the ``input_stem_blocks`` and
            ``block_groups`` attributes set by ``get_active_subnet``.
        quant_min: Lower clamp bound of the quantized integer grid.
        quant_max: Upper clamp bound of the quantized integer grid.
            NOTE(review): the defaults keep the previously hard-coded
            [0.0, 1.0] range; an 8-bit unsigned grid would be [0, 255] -
            confirm which range is intended.
    """

    def __init__(self, quat_model, float_model, quant_min=0.0, quant_max=1.0):
        super(Quant_Model, self).__init__()
        self.input_zero_points = {}
        self.input_scales = {}
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.float_model = float_model
        self.quant_model = quat_model
        self.quant_state_dict = quat_model.state_dict()

        # Dotted name -> leaf module, for the float model as it is.
        _layers = nested_children(self.float_model)
        layer_names = {}
        squash_nested_dict(_layers, layer_names)

        # Same tree, but with stem/block positions replaced by the indices
        # recorded on the float model, so names line up with the quantized model.
        _remap = remap_weight_names({
            'input_stem': self.float_model.input_stem_blocks,
            'blocks': self.float_model.block_groups,
        }, _layers)
        _remap_names = {}
        squash_nested_dict(_remap, _remap_names)

        # Both flattenings visit the leaves in the same order, so zipping the
        # key sequences pairs each float layer name with its remapped name.
        self._remapped_layer_names = dict(zip(layer_names.keys(), _remap_names.keys()))

        for name, layer in layer_names.items():
            layer.register_forward_pre_hook(self.forward_pre_hook(name))

    def forward_pre_hook(self, layer_name):
        """Create the pre-hook for the float layer called ``layer_name``.

        Returns:
            A callable ``(module, inputs)``. For ``nn.Conv2d`` / ``nn.Identity``
            modules it returns the fake-quantized first input as a single
            tensor; for any other module it returns the inputs unchanged.

        Raises (when the hook fires):
            KeyError: if the quantized state dict has no ``.scale`` /
                ``.zero_point`` entry for the mapped layer.
        """
        def pre_hook(module, x):
            if isinstance(module, (nn.Conv2d, nn.Identity)):
                with torch.no_grad():
                    quant_model_layer = self._remapped_layer_names[layer_name]
                    zero_point = self.quant_state_dict[quant_model_layer+'.zero_point'].float()
                    scale = self.quant_state_dict[quant_model_layer+'.scale'].float()
                    # Quantize, clamp to the integer grid, then de-quantize.
                    levels = torch.clamp(
                        torch.round(torch.div(x[0], scale) + zero_point),
                        self.quant_min,
                        self.quant_max,
                    ) - zero_point
                    x = levels * scale
            return x
        return pre_hook

    def forward(self, x):
        """Run the hooked float model on ``x``."""
        return self.float_model(x)



In [10]:

# Extract the largest subnet and attach the quantization config for the chosen backend.
set_max_subnet(raw_resnet)
max_subnet = get_active_subnet(raw_resnet)
max_subnet.eval()
max_subnet.qconfig = torch.quantization.get_default_qconfig(config)

# Extract a second subnet with d=0, e=0, w=0.
# NOTE(review): set_active_subnet assigns `e` directly as active_expand_ratio, while
# set_max_subnet passes max(ofa.expand_ratio_list) - confirm that 0 is a valid expand ratio.
set_active_subnet(raw_resnet, 0, 0, 0)
min_subnet = get_active_subnet(raw_resnet)
min_subnet.eval()
min_subnet.qconfig = torch.quantization.get_default_qconfig(config)

# A single random tensor is the only input the prepared models see before convert().
# NOTE(review): the resulting scales/zero-points therefore reflect random noise, not real images.
input_fp32 = torch.randn(1, 3, 224, 224, dtype=torch.float32)

# prepare -> one forward pass -> convert, for the small subnet.
min_subnet_1 = torch.quantization.prepare(min_subnet)
_ = min_subnet_1(input_fp32)
quat_min_model = torch.quantization.convert(min_subnet_1)

# Same sequence for the large subnet.
max_subnet_1 = torch.quantization.prepare(max_subnet)
_ = max_subnet_1(input_fp32)
quat_max_model = torch.quantization.convert(max_subnet_1)

# NOTE(review): these weights are loaded after quat_max_model was built; presumably the
# quantized model does not see them (prepare is called without inplace) - confirm the order.
max_subnet.load_state_dict(torch.load('./large_subnet.pth'))

# Float small subnet, fake-quantized with the scale/zero-point entries of the large quantized model.
fake_quant_model = Quant_Model(quat_max_model, min_subnet)
fake_output = fake_quant_model(input_fp32)




  reduce_range will be deprecated in a future release of PyTorch."


In [12]:
# Point the OFA ImageNet data provider at the Tiny ImageNet directory.
# NOTE(review): presumably the provider expects its own train/val folder layout - verify it matches this dataset.
ImagenetDataProvider.DEFAULT_PATH = path
run_config = ImagenetRunConfig(test_batch_size=16, n_worker=20)
run_config.data_provider.assign_active_img_size(224)

# The run manager wraps the quantized small model, but statistics are reset on the float small subnet.
# NOTE(review): the recorded output of this cell ends with an AttributeError - the cell does not run cleanly as saved.
run_manager = RunManager('./tmp/eval_subnet', quat_min_model, run_config, init=False)
run_manager.reset_running_statistics(net=min_subnet) 


Color jitter: tf, resize_scale: 0.08, img_size: 224


AttributeError: ignored

88
