In [1]:
import os
import gc
import cv2
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

#Pytorch Quantization
import torch.quantization

# Utils
from tqdm import tqdm
from collections import defaultdict

# Model Import
from ResNet20 import resnet20

In [2]:
CONFIG = dict(
    seed = 42,
    train_batch_size = 128,
    valid_batch_size = 256,
    num_calibration_batches = 32,
    num_classes = 10,
    device = torch.device("cpu"),
    bits = 8.0
)

In [3]:
MODEL_PATHS = 'ResNet20 final.bin'

In [4]:
def set_seed(seed = 42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(CONFIG['seed'])

In [5]:
def criterion(outputs, labels):
    loss = nn.CrossEntropyLoss()
    return loss(outputs, labels)

In [6]:
@torch.no_grad()
def valid_fn(model, dataloader, device, neval_batches):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    correct = 0.0
    PREDS = []
    count = 0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader), ncols=100)
    for step, data in bar:        
        inputs, targets = data
        inputs = inputs.to(device)
        targets = targets.to(device)
            
        # рачсет вывода
        output = model(inputs)
        loss = criterion(output, targets)
        
        _, preds = output.max(1)
        correct += preds.eq(targets).sum()

        batch_size = inputs.size(0)
        
        count += 1
        
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        sum_loss = running_loss / dataset_size
        
        sum_score = correct.cpu().detach().numpy() / dataset_size
        
        bar.set_postfix({'Valid_Loss':sum_loss, 'Valid_Score':sum_score})
        PREDS.append(output.view(-1).cpu().detach().numpy()) 
        if count >= neval_batches:
            PREDS = np.concatenate(PREDS)
            return sum_loss, sum_score, PREDS
        
    PREDS = np.concatenate(PREDS)
    
    gc.collect()
    return sum_loss, sum_score, PREDS

In [7]:
def uniform_symmetric_quantizer(x, bits=8.0, minv=None, maxv=None, signed=True, 
                                scale_bits=0.0, num_levels=None, scale=None, simulated=True):
    if minv is None:
        maxv = torch.max(torch.abs(x))
        minv = - maxv if signed else 0

    if signed:
        maxv = np.max([-float(minv), float(maxv)])
        minv = - maxv 
    else:
        minv = 0
    
    if num_levels is None:
        num_levels = 2 ** bits

    if scale is None:
        scale = (maxv - minv) / (num_levels - 1)

    if scale_bits > 0:
        scale_levels = 2 ** scale_bits
        scale = torch.round(torch.mul(scale, scale_levels)) / scale_levels
            
    ## clamp
    x = torch.clamp(x, min=float(minv), max=float(maxv))
        
    x_int = torch.round(x / scale)
    
    if signed:
        x_quant = torch.clamp(x_int, min=-num_levels/2, max=num_levels/2 - 1)
        assert(minv == - maxv)
    else:
        x_quant = torch.clamp(x_int, min=0, max=num_levels - 1)
        assert(minv == 0 and maxv > 0)
        
    x_dequant = x_quant * scale
    
    return x_dequant if simulated else x_quant


In [8]:
def quant_weights(w):
    '''
    Квантизация весов слоя 
    '''
    
    # uniform symmetric quantization
    qw = uniform_symmetric_quantizer(w, bits=CONFIG['bits'])

    err = float(torch.sum(torch.mul(qw - w, qw - w)))

    return qw, err

In [9]:
def quant_checkpoint(checkpoint):
    '''
    Квантизация слоев
    '''
    bits = CONFIG['bits']

    print('quantizing weights into %s bits, %s layers' % (bits, len(checkpoint.keys())))

    all_quant_error, all_quant_num = 0, 0
    for each_layer in checkpoint.keys():
        
        if '.num_batches_tracked' in each_layer or '.minv' in each_layer or '.maxv' in each_layer or 'bn' in each_layer or '.downsample' in each_layer or 'fc.bias' in each_layer :
            continue
        
        each_layer_weights = checkpoint[each_layer].clone()

        print('quantize for: %s, size: %s' % (each_layer, each_layer_weights.size()))
        print('weights range: (%.4f, %.4f)' % 
                            (torch.min(each_layer_weights), torch.max(each_layer_weights)))

        quant_error, quant_num = 0, 0
        output_channel_num = each_layer_weights.size()[0]
        # channel-wise quant for each output channel
        for c in range(output_channel_num):  
            w = each_layer_weights[c, :].clone()
            #w = each_layer_weights.clone()
            qw, err = quant_weights(w)

            each_layer_weights[c, :] = qw
            #each_layer_weights = qw
            quant_error += err
            quant_num += len(qw.reshape(-1, 1))

        all_quant_num += quant_num
        all_quant_error += quant_error

        checkpoint[each_layer] = each_layer_weights
        print('layer quant RMSE: %.4e' % np.sqrt(quant_error / quant_num))
        
    rmse = np.sqrt(all_quant_error / all_quant_num)
    print('\ntotal quant RMSE: %.4e' % rmse)

    return checkpoint, rmse

In [10]:
class QuantActivations(nn.Module):
    '''
    Квантизация активаций:
    (1) the input of conv layer
    (2) the input of linear fc layer
    (3) the input of pooling layer
    '''
    def __init__(self, act_bits, get_stats, minv=None, maxv=None, 
        calibrate_sample_size=512, calibrate_batch_size=4, topk=10):
        '''
        calibrate_sample_size: calibration sample size, typically from random training data
        calibrate_batch_size: calibration sampling batch size
        topk: calibrate topk lower and upper bounds
        '''
        super(QuantActivations, self).__init__()
        self.act_bits = act_bits
        self.get_stats = get_stats
        self.index = 0
        self.topk = topk
        self.sample_batches = calibrate_sample_size // calibrate_batch_size
        stats_size = (self.sample_batches, self.topk) if self.get_stats else 1
        
        self.register_buffer('minv', torch.zeros(stats_size))
        self.register_buffer('maxv', torch.zeros(stats_size))

    def forward(self, x):
        if self.get_stats:
            y = x.clone()
            y = torch.reshape(y, (-1,))
            y, indices = torch.sort(y)
            topk_mins = y[:self.topk]
            topk_maxs = y[-self.topk:]
            if self.index < self.sample_batches:
                self.minv[self.index, :] = topk_mins
                self.maxv[self.index, :] = topk_maxs
                self.index += 1

        if self.act_bits > 0:
            ## uniform quantization
            if self.minv is not None:
                if self.minv >= 0.0: # activation after relu
                    self.minv *= 0.0
                    self.signed = False
                else: 
                    self.maxv = max(-self.minv, self.maxv) 
                    self.minv = - self.maxv
                    self.signed = True
            x = uniform_symmetric_quantizer(x, bits=self.act_bits, 
                    minv=self.minv, maxv=self.maxv, signed=self.signed)
        return x


def quant_model_acts(model, act_bits, get_stats, calibrate_batch_size=4):
    """
    Добавление активаций к слоям
    """
    if type(model) in [nn.Conv2d, nn.Linear, nn.AdaptiveAvgPool2d]:
        quant_act = QuantActivations(act_bits, get_stats, calibrate_batch_size=calibrate_batch_size)
        return nn.Sequential(quant_act, model)
    elif type(model) == nn.Sequential:
        modules = []
        for name, module in model.named_children():
            modules.append(quant_model_acts(module, act_bits, get_stats, calibrate_batch_size=calibrate_batch_size))
        return nn.Sequential(*modules)
    else:
        quantized_model = copy.deepcopy(model)
        for attribute in dir(model):
            module = getattr(model, attribute)
            if isinstance(module, nn.Module):
                setattr(quantized_model, attribute, 
                    quant_model_acts(module, act_bits, get_stats, calibrate_batch_size=calibrate_batch_size))
        return quantized_model

In [11]:
def save_model_act_stats(model, save_path):
    checkpoint = model.state_dict()
    act_stats = copy.deepcopy(checkpoint)
    for key in checkpoint:
        if '.minv' not in key and '.maxv' not in key:
            del act_stats[key]
    torch.save(act_stats, save_path)
    return act_stats

In [12]:
quantized_model1 = resnet20()
quantized_model1.to(CONFIG['device'])
quantized_model1.load_state_dict(torch.load(MODEL_PATHS))
checkpoint = quantized_model1.state_dict()

In [13]:
#get activation stats
quantized_model1 = quant_model_acts(quantized_model1, 0, True, CONFIG['num_calibration_batches'])

In [14]:
train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]), download=True),
        batch_size=CONFIG['train_batch_size'], shuffle=True,
        num_workers=2)

Files already downloaded and verified


In [15]:
valid_fn(quantized_model1, train_loader, CONFIG['device'], CONFIG['num_calibration_batches'])

  8%|█▉                      | 31/391 [00:43<08:22,  1.40s/it, Valid_Loss=0.0744, Valid_Score=0.976]


(0.07441426627337933,
 0.97607421875,
 array([-5.9480247, -9.374774 ,  1.3879923, ..., -3.714039 , -0.9983002,
         4.2985835], dtype=float32))

In [16]:
# save the activation stats
os.makedirs('stats/', exist_ok=True)
act_stats_save_path = 'stats/%s_act_stats.pth' % "ResNet20"
save_model_act_stats(quantized_model1, act_stats_save_path)

OrderedDict([('conv.0.minv',
              tensor([[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                      [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
         

### Quantization

In [17]:
def act_clip_bounds(stats, act_clip_method, min_or_max):
    if act_clip_method.startswith('top'):
        topk = int(act_clip_method.split('_')[1])
        assert(topk <= 20)
        stats = stats[:, :topk] if min_or_max == 'min' else stats[:, -topk:]
        values, indices = torch.median(stats, 1)
        return torch.mean(values)
    else:
        raise RuntimeError("Please implement for activation clip method: %s !!!" % act_clip_method) 

In [18]:
def load_model_act_stats(model, load_path, act_clip_method):
    checkpoint = model.state_dict()
    act_stats = torch.load(load_path)
    for key in act_stats:
        min_or_max = 'min' if '.minv' in key else 'max'
        value = act_clip_bounds(act_stats[key], act_clip_method, min_or_max)
        key = key.replace('module.', '')
        checkpoint[key][0] = value
    model.load_state_dict(checkpoint)
    return model

In [19]:
quantized_model2 = resnet20()
quantized_model2.to(CONFIG['device'])
quantized_model2.load_state_dict(torch.load(MODEL_PATHS))
checkpoint = quantized_model2.state_dict()

In [20]:
# quantize weights
rmse = 0
checkpoint, rmse = quant_checkpoint(checkpoint)
# load the updated weights
quantized_model2.load_state_dict(checkpoint)
del checkpoint

quantizing weights into 8.0 bits, 128 layers
quantize for: conv.weight, size: torch.Size([16, 3, 3, 3])
weights range: (-1.7392, 2.0936)
layer quant RMSE: 2.8653e-03
quantize for: layer1.0.conv1.weight, size: torch.Size([16, 16, 3, 3])
weights range: (-0.9321, 1.2331)
layer quant RMSE: 1.6528e-03
quantize for: layer1.0.conv2.weight, size: torch.Size([16, 16, 3, 3])
weights range: (-0.8008, 0.9115)
layer quant RMSE: 1.4722e-03
quantize for: layer1.1.conv1.weight, size: torch.Size([16, 16, 3, 3])
weights range: (-0.7769, 0.7389)
layer quant RMSE: 1.3599e-03
quantize for: layer1.1.conv2.weight, size: torch.Size([16, 16, 3, 3])
weights range: (-0.7376, 0.6565)
layer quant RMSE: 1.1594e-03
quantize for: layer1.2.conv1.weight, size: torch.Size([16, 16, 3, 3])
weights range: (-1.0651, 0.7439)
layer quant RMSE: 1.4312e-03
quantize for: layer1.2.conv2.weight, size: torch.Size([16, 16, 3, 3])
weights range: (-0.7821, 0.9201)
layer quant RMSE: 1.2518e-03
quantize for: layer2.0.conv1.weight, size:

In [21]:
quantized_model2 = quant_model_acts(quantized_model2, CONFIG['bits'], False, CONFIG['num_calibration_batches'])

In [22]:
quantized_model2 = load_model_act_stats(quantized_model2, act_stats_save_path, 'top_10')

In [23]:
validation_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])),
        batch_size=CONFIG['valid_batch_size'], shuffle=False,
        num_workers=2)

In [24]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

In [25]:
def performance_inference(model, dataloader, device):
    
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
    
    start = time.time()
    history = defaultdict(list)
    
    start = time.time()
        
    val_loss, val_score, preds = valid_fn(model, dataloader, CONFIG['device'], CONFIG['valid_batch_size'])
    
    end = time.time()
    
    history['Valid Loss'].append(val_loss)
    history['Valid Score'].append(val_score)
    
    time_elapsed = end - start
    print('Validation complete in {:.0f}ms'.format(
        time_elapsed * 1000))
    print("Validation Loss: {:.4f}".format(val_loss))
    print("Validation Score: {:.4f}".format(val_score))
    
    
    return model, history

In [26]:
print("Size of model after quantization")
print_size_of_model(quantized_model2)

Size of model after quantization
Size (MB): 1.232297


In [27]:
quantized_model2, history = performance_inference(quantized_model2, validation_loader, CONFIG['device'])

[INFO] Using GPU: NVIDIA GeForce RTX 2080



100%|██████████████████████████| 40/40 [00:07<00:00,  5.06it/s, Valid_Loss=0.265, Valid_Score=0.919]

Validation complete in 7995ms
Validation Loss: 0.2648
Validation Score: 0.9186



