<a href="https://colab.research.google.com/github/monishramadoss/ofa_quant/blob/main/ofa_quant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
# Install Once-for-All plus the torch/torchvision pair this notebook pins.
!pip install ofa
!pip install torch==1.4.0 torchvision==0.5.0
# Clone the Tiny-ImageNet mirror only when the checkout is missing, so re-running this
# cell does not abort with "destination path 'IMagenet' already exists".
!test -d IMagenet || git clone https://github.com/seshuad/IMagenet
!ls 'IMagenet/tiny-imagenet-200/'

fatal: destination path 'IMagenet' already exists and is not an empty directory.
test  train  val  wnids.txt  words.txt


In [9]:
import time
import skimage.io as nd
import numpy as np

# Root of the Tiny-ImageNet checkout. The loader helpers in this cell build file names by
# plain string concatenation (e.g. path + 'wnids.txt'), so the trailing slash is required.
path = 'IMagenet/tiny-imagenet-200/'

def get_id_dictionary(root=None):
    """Map every id listed in ``wnids.txt`` to its 0-based line index.

    Args:
        root: Prefix that is concatenated as-is with ``'wnids.txt'``, so it must
            end with a path separator. When omitted, the module-level ``path``
            is used.

    Returns:
        dict: ``{id_string: line_index}`` in file order.
    """
    if root is None:
        root = path
    # The context manager closes the handle as soon as the mapping is built.
    with open(root + 'wnids.txt', 'r') as wnid_file:
        return {line.replace('\n', ''): i for i, line in enumerate(wnid_file)}
  
def get_class_to_id_dict():
    """Build a lookup from class index to ``(id_string, description)``.

    The class indices come from ``get_id_dictionary()``; the descriptions are the
    second tab-separated field of each line in ``words.txt`` under ``path``.

    Returns:
        dict: ``{class_index: (id_string, description)}``.

    Raises:
        KeyError: If an id from ``wnids.txt`` has no entry in ``words.txt``.
    """
    id_dict = get_id_dictionary()
    all_classes = {}
    with open(path + 'words.txt', 'r') as words_file:
        for line in words_file:
            # Drop the line terminator first; otherwise it stays glued to the
            # description whenever the line has exactly two fields.
            line = line.rstrip('\n')
            if not line:
                continue  # tolerate blank lines instead of failing the unpack below
            n_id, word = line.split('\t')[:2]
            all_classes[n_id] = word
    return {value: (key, all_classes[key]) for key, value in id_dict.items()}

def get_data(id_dict):
    """Load the train and validation images together with one-hot labels.

    Args:
        id_dict: Mapping ``{id_string: class_index}``; the index selects the hot
            column of each label row.

    Returns:
        list: ``[train_data, train_labels, test_data, test_labels]``, each a NumPy
        array. Label rows have 200 columns.
    """
    num_classes = 200        # width of every one-hot label row
    images_per_class = 500   # train files are named <id>_0.JPEG ... <id>_499.JPEG

    print('starting loading data')
    train_data, test_data = [], []
    train_labels, test_labels = [], []
    t = time.time()
    for key, value in id_dict.items():
        train_data += [
            nd.imread(path + 'train/{}/images/{}_{}.JPEG'.format(key, key, str(i)))
            for i in range(images_per_class)
        ]
        train_labels_ = np.zeros((images_per_class, num_classes), dtype=int)
        train_labels_[:, value] = 1
        train_labels += train_labels_.tolist()

    with open(path + 'val/val_annotations.txt') as annotations:
        for line in annotations:
            img_name, class_id = line.split('\t')[:2]
            # Append to the list; rebinding ``test_data`` to the image would discard
            # everything collected so far.
            test_data.append(nd.imread(path + 'val/images/{}'.format(img_name)))
            test_labels_ = np.zeros((1, num_classes), dtype=int)
            test_labels_[0, id_dict[class_id]] = 1
            test_labels += test_labels_.tolist()

    print('finished loading data, in {} seconds'.format(time.time() - t))
    # NOTE(review): np.array() over the image lists assumes every image decodes to the
    # same shape -- confirm for this dataset.
    test = np.array(test_data), np.array(test_labels)
    return [np.array(train_data), np.array(train_labels), *test]

# train_data, train_labels, test_data, test_labels = get_data(get_id_dictionary())

In [10]:
import copy
import torch
from ofa.model_zoo import ofa_net
from ofa.imagenet_classification.data_providers.imagenet import ImagenetDataProvider
from ofa.imagenet_classification.run_manager import ImagenetRunConfig, RunManager

# Load the pretrained 'ofa_resnet50' super-network and extract two standalone sub-networks:
# one with the maximum configuration active, one with d=0, w=0, e=0 active.
# Each get_active_subnet() call reflects whatever configuration was set just before it,
# so the statement order below matters.
raw_resnet = ofa_net('ofa_resnet50', pretrained=True)
raw_resnet.set_max_net()
max_subnet = raw_resnet.get_active_subnet(preserve_weight=True)
raw_resnet.set_active_subnet(d=0, w=0, e=0)
min_subnet = raw_resnet.get_active_subnet(preserve_weight=True)


# Point the ImageNet data provider at the Tiny-ImageNet checkout.
# NOTE(review): the validation output recorded in this notebook reports top1 of about 0.1;
# presumably the Tiny-ImageNet folder layout / 200-class labels do not match what the
# ImageNet provider and the pretrained classifier expect -- confirm.
ImagenetDataProvider.DEFAULT_PATH = path
# NOTE(review): n_worker=20 is presumably the number of data-loader workers; verify it
# suits the machine this runs on.
run_config = ImagenetRunConfig(test_batch_size=16, n_worker=20)
run_config.data_provider.assign_active_img_size(224)


# The run manager is constructed around min_subnet, and running statistics are reset only
# for min_subnet here.
# NOTE(review): max_subnet is validated later without a matching
# reset_running_statistics() call -- confirm that is intended.
run_manager = RunManager('./tmp/eval_subnet', min_subnet, run_config, init=False, )
run_manager.reset_running_statistics(net=min_subnet)

Color jitter: tf, resize_scale: 0.08, img_size: 224
ResNets(
  (input_stem): ModuleList(
    (0): ConvLayer(
      (conv): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (1): ConvLayer(
      (conv): Conv2d(24, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
  )
  (max_pooling): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (blocks): ModuleList(
    (0): ResNetBottleneckBlock(
      (conv1): Sequential(
        (conv): Conv2d(40, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (c

In [11]:
# Validate the largest sub-network first, then the smallest, printing one summary line each.
for eval_net in (max_subnet, min_subnet):
    loss, (top1, top5) = run_manager.validate(net=eval_net)
    print('Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (loss, top1, top5))

Validate Epoch #1 : 100%|██████████| 625/625 [02:08<00:00,  4.86it/s, loss=8.46, top1=0.06, top5=0.42, img_size=224]


Results: loss=8.45929,	 top1=0.1,	 top5=0.4


Validate Epoch #1 : 100%|██████████| 625/625 [00:37<00:00, 16.79it/s, loss=7.03, top1=0.01, top5=0.2, img_size=224]

Results: loss=7.02851,	 top1=0.0,	 top5=0.2





In [12]:
from ofa.imagenet_classification.networks import ResNets
from ofa.utils.layers import IdentityLayer, ResidualBlock

def val2list(val, repeat_time=1):
    """Coerce ``val`` into a list-like value.

    Lists and NumPy arrays are handed back untouched (same object), tuples are
    copied into a new list, and any other value is repeated ``repeat_time``
    times in a new list.
    """
    if isinstance(val, (list, np.ndarray)):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val for _ in range(repeat_time)]

def set_active_subnet(ofa, d=None, e=None, w=None, **kwargs):
    """Select the active depth / expand-ratio / width settings on ``ofa`` in place.

    Args:
        ofa: Network object exposing ``BASE_DEPTH_LIST``, ``blocks``, ``input_stem``,
            ``depth_list``, ``grouped_block_index`` and ``runtime_depth``.
        d: Depth setting; a scalar is repeated ``len(BASE_DEPTH_LIST) + 1`` times.
            The first entry controls stem skipping, the rest apply per stage.
        e: Expand ratio; a scalar is repeated once per block.
        w: Index into each layer's ``out_channel_list``; a scalar is repeated
            ``len(BASE_DEPTH_LIST) + 2`` times. The first two entries apply to
            the input stem, the rest apply per stage.
        **kwargs: Accepted and ignored.

    Entries that are ``None`` leave the corresponding setting unchanged.
    """
    n_base_stages = len(ofa.BASE_DEPTH_LIST)
    depth = val2list(d, n_base_stages + 1)
    expand_ratio = val2list(e, len(ofa.blocks))
    width_mult = val2list(w, n_base_stages + 2)

    # Per-block expand ratio.
    for block, ratio in zip(ofa.blocks, expand_ratio):
        if ratio is None:
            continue
        block.active_expand_ratio = ratio

    # Input stem: layers 0 and 1 share one width index, layer 2 has its own.
    if width_mult[0] is not None:
        shared_channels = ofa.input_stem[0].out_channel_list[width_mult[0]]
        ofa.input_stem[1].conv.active_out_channel = shared_channels
        ofa.input_stem[0].active_out_channel = shared_channels
    if width_mult[1] is not None:
        ofa.input_stem[2].active_out_channel = ofa.input_stem[2].out_channel_list[width_mult[1]]

    # The stem is skipped whenever the requested stem depth is not the maximum depth.
    if depth[0] is not None:
        ofa.input_stem_skipping = (depth[0] != max(ofa.depth_list))

    stage_settings = zip(ofa.grouped_block_index, depth[1:], width_mult[2:])
    for stage_id, (block_idx, stage_depth, stage_width) in enumerate(stage_settings):
        if stage_depth is not None:
            # runtime_depth stores how far the stage is below the maximum depth.
            ofa.runtime_depth[stage_id] = max(ofa.depth_list) - stage_depth
        if stage_width is not None:
            for idx in block_idx:
                ofa.blocks[idx].active_out_channel = ofa.blocks[idx].out_channel_list[stage_width]

def set_max_subnet(ofa):
    """Activate the largest setting: max depth, max expand ratio, last width index."""
    largest_depth = max(ofa.depth_list)
    largest_expand = max(ofa.expand_ratio_list)
    last_width_index = len(ofa.width_mult_list) - 1
    set_active_subnet(ofa, d=largest_depth, e=largest_expand, w=last_width_index)

def get_active_subnet(ofa, preserve_weight=True):
    """Assemble a standalone ``ResNets`` from the currently active settings of ``ofa``.

    Args:
        ofa: Network whose ``input_stem``, ``blocks`` and ``classifier`` layers each
            provide ``get_active_subnet(in_channels, preserve_weight)``.
        preserve_weight: Forwarded unchanged to every layer's ``get_active_subnet``.

    Returns:
        A ``ResNets`` instance carrying two extra attributes: ``input_stem_blocks``
        (indices of the stem layers that were kept) and ``block_groups`` (the kept
        block indices, one list per stage).
    """
    first_stem = ofa.input_stem[0]
    input_stem = [first_stem.get_active_subnet(3, preserve_weight)]
    active_out = first_stem.active_out_channel
    input_stem_blocks = [0]

    # The middle stem layer is kept only when stem skipping is off.
    if ofa.input_stem_skipping <= 0:
        middle_stem = ResidualBlock(
            ofa.input_stem[1].conv.get_active_subnet(active_out, preserve_weight),
            IdentityLayer(active_out, active_out)
        )
        input_stem.append(middle_stem)
        input_stem_blocks.append(1)
    input_stem_blocks.append(2)

    last_stem = ofa.input_stem[2]
    input_stem.append(last_stem.get_active_subnet(active_out, preserve_weight))
    input_channel = last_stem.active_out_channel

    blocks, block_groups = [], []
    for stage_id, block_idx in enumerate(ofa.grouped_block_index):
        # runtime_depth holds how many trailing blocks of the stage are dropped.
        n_dropped = ofa.runtime_depth[stage_id]
        kept_idx = block_idx[:len(block_idx) - n_dropped]
        block_groups.append(kept_idx)
        for idx in kept_idx:
            layer = ofa.blocks[idx]
            blocks.append(layer.get_active_subnet(input_channel, preserve_weight))
            # The next block consumes this block's active output width.
            input_channel = layer.active_out_channel

    classifier = ofa.classifier.get_active_subnet(input_channel, preserve_weight)

    subnet = ResNets(input_stem, blocks, classifier)
    subnet.set_bn_param(**ofa.get_bn_param())
    subnet.input_stem_blocks = input_stem_blocks
    subnet.block_groups = block_groups
    return subnet


In [14]:
# Re-extract the smallest (d=0, e=0, w=0) and the largest sub-network with the notebook's
# own set_active_subnet / get_active_subnet helpers. Each extraction reads whatever
# setting was applied to raw_resnet immediately before it.
set_active_subnet(raw_resnet, 0, 0, 0)
min_subnet = get_active_subnet(raw_resnet)

set_max_subnet(raw_resnet)
max_subnet = get_active_subnet(raw_resnet)

# Post-training quantization of the largest sub-network: eval mode, default qconfig for
# the chosen backend, prepare, then convert.
# NOTE(review): no forward pass runs between prepare() and convert(), so the inserted
# observers presumably never see any data -- confirm whether a calibration pass is missing.
# NOTE(review): 'fbgemm' is, per the PyTorch docs, a CPU backend; moving the converted
# model to 'cuda:0' is likely unsupported -- confirm.
max_subnet.eval()
config = 'fbgemm'
max_subnet.qconfig = torch.quantization.get_default_qconfig(config)
max_subnet_1 = torch.quantization.prepare(max_subnet)
quat_max_model = torch.quantization.convert(max_subnet_1).to('cuda:0')

# Same sequence for the smallest sub-network; the review notes above apply here as well.
min_subnet.eval()
min_subnet.qconfig = torch.quantization.get_default_qconfig(config)
min_subnet_1 = torch.quantization.prepare(min_subnet)
quat_min_model = torch.quantization.convert(min_subnet_1).to('cuda:0')


In [None]:
# Validate the converted largest model first, then the converted smallest one, printing
# one summary line each.
for eval_net in (quat_max_model, quat_min_model):
    loss, (top1, top5) = run_manager.validate(net=eval_net)
    print('Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (loss, top1, top5))