In [2]:
import os.path
import logging
import time
import torch
import torch.nn.utils.prune as prune
import torchvision.transforms.functional as TVF
import numpy as np

from torch.autograd import Variable
from PIL import Image
from collections import OrderedDict
from pathlib import Path
from tqdm import tqdm

from MSRResNet.utils import utils_logger
from MSRResNet.utils import utils_image as util
from MSRResNet.SRResNet import MSRResNet, PrunedMSRResNet
from CARN import CARN
from rdn import RDN


def get_logger(name):
    '''Set up and return the logger registered under ``name``.

    ``utils_logger.logger_info`` is invoked with a log path of
    ``<name>.log`` (presumably attaching the handlers - confirm in
    MSRResNet.utils), after which the logger level is forced to DEBUG.
    '''
    log_file = '{}.log'.format(name)
    utils_logger.logger_info(name, log_path=log_file)
    configured = logging.getLogger(name)
    configured.setLevel(logging.DEBUG)
    return configured

def count_num_of_parameters(model):
    '''Return the total number of elements over ``model.parameters()``.'''
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

def _load_img_array(path, color_mode='RGB',
                    channel_mean=None, modcrop=(0, 0, 0, 0)):
    '''Load an image using PIL and convert it into specified color space,
    and return it as a float32 numpy array scaled to the 0-1 range.
    https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
    The code is modified from Keras.preprocessing.image.load_img, img_to_array.

    Args:
        path: image location, passed straight to ``PIL.Image.open``.
        color_mode: 'RGB', 'YCbCr', or 'Y' (first YCbCr channel only, kept
            as a trailing axis of size 1).
        channel_mean: optional per-channel values subtracted after scaling;
            one entry is consumed per channel present in the result.
        modcrop: ``(top, bottom, left, right)`` pixel margins to remove.
            A zero entry leaves that side untouched.

    Returns:
        Array of shape (height, width, channels).

    Raises:
        ValueError: if ``color_mode`` is not one of the supported modes.
    '''
    if color_mode not in ('RGB', 'YCbCr', 'Y'):
        raise ValueError(
            "color_mode must be 'RGB', 'YCbCr' or 'Y', got {!r}".format(color_mode))
    # The context manager releases the file handle as soon as pixels are copied.
    with Image.open(path) as img:
        cimg = img.convert('RGB' if color_mode == 'RGB' else 'YCbCr')
        # np.array (not asarray) guarantees an owned, writable buffer.
        x = np.array(cimg, dtype='float32')
    if color_mode == 'Y':
        x = x[:, :, 0:1]
    # Normalize To 0-1
    x *= 1.0 / 255.0
    # `is not None` keeps numpy arrays usable as channel_mean.
    if channel_mean is not None and len(channel_mean) > 0:
        for channel in range(x.shape[2]):
            x[:, :, channel] -= channel_mean[channel]
    top, bottom, left, right = modcrop
    if top or bottom or left or right:
        height, width = x.shape[:2]
        # Explicit end indices so that a zero margin does not turn into the
        # empty slice `[start:-0]`.
        x = x[top:max(height - bottom, 0), left:max(width - right, 0), :]
    return x

def _rgb2ycbcr(img, maxVal=255):
    # Same as MATLAB's rgb2ycbcr
    # Updated at 03/14/2017
    # Not tested for cb and cr
    O = np.array([[16],
                  [128],
                  [128]])
    T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
                  [-0.148223529411765, -0.290992156862745, 0.439215686274510],
                  [0.439215686274510, -0.367788235294118, -0.071427450980392]])
    if maxVal == 1:
        O = O / 255.0
    t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
    t = np.dot(t, np.transpose(T))
    t[:, 0] += O[0]
    t[:, 1] += O[1]
    t[:, 2] += O[2]
    ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
    return ycbcr

def psnr(y_true, y_pred, shave_border=4):
    '''Peak signal-to-noise ratio (dB) between two images.

    Input must be 0-255, 2D.

    Args:
        y_true: reference image.
        y_pred: image to score; same shape as ``y_true``.
        shave_border: number of pixels dropped from every edge before the
            error is measured; 0 (or less) keeps the whole image.

    Returns:
        ``20 * log10(255 / rmse)``, or ``float('inf')`` when the compared
        regions are identical.
    '''
    target_data = np.array(y_true, dtype=np.float32)
    ref_data = np.array(y_pred, dtype=np.float32)
    diff = ref_data - target_data
    if shave_border > 0:
        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
    mse = np.mean(np.power(diff, 2))
    if mse == 0:
        # Identical images: return infinity explicitly instead of dividing
        # by a zero RMSE and emitting a RuntimeWarning.
        return float('inf')
    rmse = np.sqrt(mse)
    return 20 * np.log10(255. / rmse)

def calculate_psnr(testsets_H, testset_O, logger):
    '''Average luma PSNR between two folders of PNG images.

    The ``*.png`` files of both folders are sorted by path and paired by
    position; each pair is compared on the Y channel of ``_rgb2ycbcr`` with
    a 4-pixel border shaved off.

    Args:
        testsets_H: folder holding the ground-truth images.
        testset_O: folder holding the images to score.
        logger: logger used to report a mismatch in file counts (may be
            None to stay silent).

    Returns:
        Mean PSNR over all compared pairs.

    Raises:
        ValueError: if either folder contains no PNG file.
    '''
    h_list = sorted(Path(testsets_H).glob('*.png'))
    o_list = sorted(Path(testset_O).glob('*.png'))
    if not h_list or not o_list:
        raise ValueError(
            'No PNG pairs to compare: {} file(s) in {!r}, {} file(s) in {!r}'.format(
                len(h_list), str(testsets_H), len(o_list), str(testset_O)))
    if len(h_list) != len(o_list) and logger is not None:
        # zip() below stops at the shorter list; make the truncation visible.
        logger.warning(
            'PNG count mismatch: {} reference vs {} output image(s); '
            'only the first {} pair(s) are compared'.format(
                len(h_list), len(o_list), min(len(h_list), len(o_list))))
    pairs = list(zip(h_list, o_list))

    def to_uint8(unit_range_img):
        # Round before casting: a plain astype() truncates, and the
        # /255 -> *255 float32 round trip can land just below the integer.
        return np.clip(np.rint(unit_range_img * 255), 0, 255).astype(np.uint8)

    total_psnr = 0.0
    for h_path, o_path in tqdm(pairs):
        # Load label image
        h_img = to_uint8(_load_img_array(h_path))
        # Load output image
        o_img = to_uint8(_load_img_array(o_path))
        total_psnr += psnr(_rgb2ycbcr(h_img)[:, :, 0],
                           _rgb2ycbcr(o_img)[:, :, 0],
                           4)
    return total_psnr / len(pairs)

# Initialize variables
data_dir = '../dataset'
root_dir = 'MSRResNet'
save = False
# --------------------------------
# basic settings
# --------------------------------
testsets = f'{data_dir}/DIV2K'
testset_L = 'DIV2K_valid_LR_bicubic'
testset_H = 'DIV2K_valid_HR'
L_folder = os.path.join(testsets, testset_L, 'X4')   # low-resolution inputs
H_folder = os.path.join(testsets, testset_H)         # ground-truth images
E_folder = 'results'                                 # outputs of the original model
P_folder = 'pruned'                                  # outputs of the pruned model
# Pick the device first so the CUDA-only calls below can be skipped on
# CPU-only machines instead of raising before the fallback is reached.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    torch.cuda.current_device()  # presumably forces CUDA initialisation up front
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
logger = get_logger('AIM-track')
# Alternative model configurations (uncomment one group to switch):
#model_path = os.path.join(root_dir, 'MSRResNetx4_model', 'MSRResNetx4.pth')
#model_name = 'MSRResNet'
#model_name = 'CARN'
#model_path = os.path.join('CARN', 'CARN_model', 'v2_e10000.pth')#'carn.pth')
#rewind_path = os.path.join('CARN', 'CARN_model', 'v2_e0.pth')
#model_name = 'PCARN32'
#model_path = os.path.join('CARN', 'CARN_model', 'v8_e8000.pth')#'carn.pth')
#model_path = os.path.join('srdensenet_v0.pth')
#model_name = 'SRDenseNet'
#model_path = os.path.join(root_dir, 'MSRResNetx4_model', 'PrunedMSRResNetx4_v2_10000.pth')
#rewind_path = os.path.join(root_dir, 'MSRResNetx4_model', 'PrunedMSRResNetx4_v2_0.pth')
#model_name = 'MSRResNet'
# Active configuration.
model_name = 'PRDN_58'
model_path = 'PRDN_v8_e2000.pth'
rewind_path = 'PRDN_v8_e0.pth'

LogHandlers setup!


In [4]:
def load_model(model_path, device, name='MSRResNet'):
    '''Build the network called ``name``, load its weights and freeze it.

    The returned model is in eval mode, has ``requires_grad`` disabled on
    every parameter and lives on ``device``.

    Args:
        model_path: checkpoint file. For 'MSRResNet' the file is the state
            dict itself; for 'SRDenseNet' it holds a whole model under the
            'model' key; for every other name the state dict is stored
            under the 'net' key.
        device: target device; also used as ``map_location`` so that a
            checkpoint saved on GPU can be opened on a CPU-only machine.
        name: one of 'MSRResNet', 'PMSRResNet', 'CARN', 'PCARN', 'PCARN32',
            'SRDenseNet', 'RDN', 'PRDN_58'.

    Returns:
        The prepared model. 'PMSRResNet', 'CARN' and 'PCARN32' come back
        wrapped in ``torch.nn.DataParallel``; 'PRDN_58' is unwrapped to the
        inner module before being returned.

    Raises:
        ValueError: if ``name`` is not a supported model name.
    '''
    # --------------------------------
    # load model
    # --------------------------------
    # Lazy constructors: nothing is built until the name has been validated.
    builders = {
        'MSRResNet': lambda: MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=4),
        'PMSRResNet': lambda: torch.nn.DataParallel(PrunedMSRResNet()),
        'CARN': lambda: torch.nn.DataParallel(CARN(multi_scale=4, group=1)),
        'PCARN': lambda: CARN(multi_scale=4, group=1, channel_cnt=12),
        'PCARN32': lambda: torch.nn.DataParallel(
            CARN(multi_scale=4, group=1, channel_cnt=32)),
        'SRDenseNet': None,  # the checkpoint carries the model object itself
        'RDN': lambda: RDN(scale_factor=4, num_channels=3, num_features=64,
                           growth_rate=64, num_blocks=16, num_layers=8),
        'PRDN_58': lambda: torch.nn.DataParallel(
            RDN(scale_factor=4, num_channels=3, num_features=58,
                growth_rate=58, num_blocks=16, num_layers=8)),
    }
    if name not in builders:
        raise ValueError('Unknown model name {!r}; expected one of {}'.format(
            name, sorted(builders)))

    # NOTE: torch.load unpickles the file - only open trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    if name == 'SRDenseNet':
        model = checkpoint['model']
    else:
        model = builders[name]()
        state_dict = checkpoint if name == 'MSRResNet' else checkpoint['net']
        model.load_state_dict(state_dict, strict=True)

    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    if name == 'PRDN_58':
        # Weights were saved from a DataParallel wrapper; hand back the bare
        # network so module names carry no 'module.' prefix.
        model = model.module
    return model

# Build the configured network once and report its parameter count.
model = load_model(model_path, device, name=model_name)
param_count = count_num_of_parameters(model)
logger.info(f'Params number: {param_count}')

20-07-09 01:28:35.212 : Params number: 18292159


In [5]:
def _self_ensemble(model, img_L):
    '''Average ``model`` outputs over the 8 flip/rotation variants of ``img_L``.

    Each variant is a rotation by k*90 degrees over the last two axes,
    optionally followed by a horizontal flip; the inverse transform is
    applied to the output before averaging.
    '''
    total = None
    for flipped in (False, True):
        for k in range(4):
            augmented = torch.rot90(img_L, k, dims=(-2, -1))
            if flipped:
                augmented = augmented.flip(-1)
            restored = model(augmented)
            # Undo the transforms in reverse order.
            if flipped:
                restored = restored.flip(-1)
            restored = torch.rot90(restored, (4 - k) % 4, dims=(-2, -1))
            total = restored if total is None else total + restored
    return total / 8


def test(model, L_folder, out_folder, logger, save, ensemble=False):
    '''Run ``model`` on every image of ``L_folder`` and log the mean runtime.

    Args:
        model: network to evaluate; inputs are moved to the device of its
            first parameter (CPU if it has none).
        L_folder: folder of input images, listed via ``util.get_image_paths``.
        out_folder: folder created up front; receives the outputs when
            ``save`` is true.
        logger: logger for progress and timing messages.
        save: when true, each output is written to ``out_folder`` under the
            integer that precedes the first 'x' of the input file name.
        ensemble: when true, average predictions over the 8 flip/rotation
            variants of the input. The earlier 7-pass version applied the
            180-degree rotation twice (flip(-1).flip(-2) is the same
            transform) and skipped the transposed variants.

    Timing uses CUDA events on GPU and ``time.perf_counter`` otherwise.
    '''
    # --------------------------------
    # read image
    # --------------------------------
    util.mkdir(out_folder)

    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')
    on_cuda = run_device.type == 'cuda'

    # record runtime (milliseconds per image)
    runtimes_ms = []

    logger.info(L_folder)
    logger.info(out_folder)

    for idx, img in enumerate(util.get_image_paths(L_folder), start=1):

        # --------------------------------
        # (1) img_L
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name+ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(run_device)

        if on_cuda:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
        else:
            wall_start = time.perf_counter()

        # Inference only: make sure no autograd graph is kept alive.
        with torch.no_grad():
            img_E = _self_ensemble(model, img_L) if ensemble else model(img_L)

        if on_cuda:
            end.record()
            torch.cuda.synchronize()
            runtimes_ms.append(start.elapsed_time(end))  # milliseconds
        else:
            runtimes_ms.append((time.perf_counter() - wall_start) * 1000.0)

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)

        if save:
            new_name = '{:3d}'.format(int(img_name.split('x')[0]))
            path = os.path.join(out_folder, new_name+ext)
            logger.info('Save {:4d} to {:10s}'.format(idx, path))
            util.imsave(img_E, path)

    if not runtimes_ms:
        logger.warning('No images found in ({}); nothing was evaluated'.format(L_folder))
        return
    ave_runtime = sum(runtimes_ms) / len(runtimes_ms) / 1000.0
    logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(L_folder, ave_runtime))

# Super-resolve the validation set with the ensemble enabled, saving the
# outputs, then score the saved images against the ground truth.
test(model, L_folder, E_folder, logger, save=True, ensemble=True)
psnr_before = calculate_psnr(H_folder, E_folder, logger)
logger.info(f'PSNR before pruning {psnr_before}')

20-07-09 01:28:36.073 : ../dataset/DIV2K/DIV2K_valid_LR_bicubic/X4
20-07-09 01:28:36.074 : results
20-07-09 01:28:36.075 : ---1--> 0801x4.png
20-07-09 01:28:42.807 : Save    1 to results/801.png
20-07-09 01:28:42.968 : ---2--> 0802x4.png
20-07-09 01:28:46.673 : Save    2 to results/802.png
20-07-09 01:28:46.825 : ---3--> 0803x4.png
20-07-09 01:28:54.036 : Save    3 to results/803.png
20-07-09 01:28:54.212 : ---4--> 0804x4.png
20-07-09 01:29:00.117 : Save    4 to results/804.png
20-07-09 01:29:00.257 : ---5--> 0805x4.png
20-07-09 01:29:04.468 : Save    5 to results/805.png
20-07-09 01:29:04.626 : ---6--> 0806x4.png
20-07-09 01:29:08.359 : Save    6 to results/806.png
20-07-09 01:29:08.494 : ---7--> 0807x4.png
20-07-09 01:29:14.313 : Save    7 to results/807.png
20-07-09 01:29:14.444 : ---8--> 0808x4.png
20-07-09 01:29:18.189 : Save    8 to results/808.png
20-07-09 01:29:18.316 : ---9--> 0809x4.png
20-07-09 01:29:22.072 : Save    9 to results/809.png
20-07-09 01:29:22.209 : --10--> 0810x

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.manifold import TSNE
from tqdm import tqdm

# Wide, short default figures: every plot below is a 1x2 panel row.
plt.rc('figure', figsize=(14, 4))

def visualize_both(name, norm):
    '''Plot the distribution of ``norm`` twice, side by side.

    The left panel shows the raw values on a linear x axis; the right panel
    shows their absolute values on a logarithmic x axis.

    Args:
        name: label used for the figure title and the legend entries.
        norm: 1-D collection of values (CPU tensor or array-like).
    '''
    values = np.asarray(norm, dtype=float)
    fig, ax = plt.subplots(1, 2)
    ax[0].set_xscale('linear')
    visualize_norm_stat(f"{name}_lin", values, ax[0])
    ax[1].set_xscale('log')
    visualize_norm_stat(f"{name}_log", np.abs(values), ax[1])
    fig.suptitle(name)
    plt.show()
    # This runs once per layer; release the figure so memory does not grow.
    plt.close(fig)

def visualize_norm_stat(name, norm, ax):
    '''Draw a rug plot of ``norm`` on ``ax`` plus its kernel density estimate.

    Args:
        name: legend label of the density curve.
        norm: 1-D collection of values (CPU tensor or array-like).
        ax: matplotlib axes (anything offering ``plot`` and ``legend``).

    The density curve is skipped when fewer than two values are given or
    when all values are equal, since no kernel estimate can be fitted then.
    '''
    values = np.asarray(norm, dtype=float).ravel()
    # Draw each point
    ax.plot(values, np.zeros_like(values), '|')
    lowest, highest = (values.min(), values.max()) if values.size else (0.0, 0.0)
    if values.size < 2 or highest == lowest:
        return
    estimator = stats.gaussian_kde(values, bw_method='silverman')
    # Draw kernel density estimate on a grid that always spans the data:
    # pad by 10% of the range on both sides (scaling min/max by 1.1 cut off
    # positive minima) and use a fixed point count instead of a fixed 0.1
    # step, which degenerates for narrow ranges.
    margin = 0.1 * (highest - lowest)
    X = np.linspace(lowest - margin, highest + margin, 512)
    K = estimator(X)
    ax.plot(X, K, label=f'{name}')
    # Set other things
    ax.legend(loc='best')

# Plot, for every Conv2d layer except the mean-shift layers, the sum of the
# weights of each output filter.
filter_unprune = ['module.sub_mean.shifter', 'module.add_mean.shifter']
for name, module in model.named_modules():
    is_conv = 'Conv2d' in type(module).__name__
    if is_conv and name not in filter_unprune:
        per_filter_sum = module.weight.sum(-1).sum(-1).sum(-1)
        visualize_both(name, per_filter_sum.detach().cpu())

In [6]:
# RDN naive l2 channel pruning
import copy
from collections import OrderedDict

def _kept_channel_mask(conv, amount, dim):
    '''Boolean mask over axis ``dim`` of ``conv.weight`` marking the channels
    that survive L2-norm structured pruning of the fraction ``amount``.

    The pruning is applied to a throwaway copy, and the mask is read from
    the ``weight_mask`` buffer rather than inferred from "weight sum != 0",
    which would also drop a live channel whose weights happen to cancel out.
    '''
    probe = copy.deepcopy(conv)
    prune.ln_structured(probe, 'weight', amount=amount, n=2, dim=dim)
    other_dims = [d for d in range(probe.weight_mask.dim()) if d != dim]
    return probe.weight_mask.sum(dim=other_dims) != 0


rewind_model = load_model(rewind_path, device, model_name)
model_new = copy.deepcopy(model)
# Kind of filter
filter_unprune = []          # Conv2d layers left completely untouched
filter_in = ['sfe1']         # layers whose input channels are all kept
filter_out = ['output']      # layers whose output channels are all kept
amount = 6 / 58              # fraction pruned per axis: 58 -> 52 channels
# Index the rewind model once instead of rescanning it for every layer.
rewind_modules = dict(rewind_model.named_modules())
# Pruning!
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
for name, module in model_new.named_modules():
    # Prune or not
    module_name = type(module).__name__
    unprune = 'Conv2d' not in module_name or name in filter_unprune
    logger.debug("=" * 20)
    logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
    if unprune:
        continue
    # Prepare pruning
    rewind_module = rewind_modules.get(name)
    if rewind_module is None:
        raise KeyError(f"rewind model has no module named {name!r}")
    pre_module_weight_shape = module.weight.shape
    pre_module_bias_shape = None if module.bias is None else module.bias.shape
    weight_device = module.weight.device
    # Prune in_channel: masks are chosen on the trained weights.
    if name not in filter_in:
        in_channel_mask = _kept_channel_mask(module, amount, dim=1)
    else:
        in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool,
                                     device=weight_device)
    # Prune out_channel
    if name not in filter_out:
        out_channel_mask = _kept_channel_mask(module, amount, dim=0)
    else:
        out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool,
                                      device=weight_device)
    # The surviving channels are filled with the rewind checkpoint's values.
    rewind_weight = rewind_module.weight
    in_channel_mask = in_channel_mask.to(rewind_weight.device)
    out_channel_mask = out_channel_mask.to(rewind_weight.device)
    module.weight = torch.nn.Parameter(
        rewind_weight[:, in_channel_mask][out_channel_mask, :])
    if module.bias is not None:
        module.bias = torch.nn.Parameter(rewind_module.bias[out_channel_mask])
    # Keep the layer's bookkeeping in line with the new weight shape.
    module.out_channels = module.weight.shape[0]
    module.in_channels = module.weight.shape[1] * module.groups
    logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
    logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
    logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> "
                 f"{None if module.bias is None else module.bias.shape}")
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

20-07-09 01:37:34.209 : Params number(Before prune): 18292159
20-07-09 01:37:34.211 : , RDN: Prune-False
20-07-09 01:37:34.213 : sfe1, Conv2d: Prune-True
20-07-09 01:37:34.221 : mask_index.shape: torch.Size([58])
20-07-09 01:37:34.222 : module.weight.shape: torch.Size([58, 3, 3, 3]) --> torch.Size([52, 3, 3, 3])
20-07-09 01:37:34.222 : module.bias.shape: torch.Size([58]) --> torch.Size([52])
20-07-09 01:37:34.224 : sfe2, Conv2d: Prune-True
20-07-09 01:37:34.228 : mask_index.shape: torch.Size([58])
20-07-09 01:37:34.229 : module.weight.shape: torch.Size([58, 58, 3, 3]) --> torch.Size([52, 52, 3, 3])
20-07-09 01:37:34.230 : module.bias.shape: torch.Size([58]) --> torch.Size([52])
20-07-09 01:37:34.231 : rdbs, ModuleList: Prune-False
20-07-09 01:37:34.233 : rdbs.0, RDB: Prune-False
20-07-09 01:37:34.234 : rdbs.0.layers, Sequential: Prune-False
20-07-09 01:37:34.236 : rdbs.0.layers.0, DenseLayer: Prune-False
20-07-09 01:37:34.237 : rdbs.0.layers.0.conv, Conv2d: Prune-True
20-07-09 01:37:34

In [None]:
# Prune CARN naively
import copy
from collections import OrderedDict

rewind_model = load_model(rewind_path, device, model_name)
model_new = copy.deepcopy(model)
# load_model returns some networks wrapped in DataParallel and others bare;
# unwrap only when there is a wrapper so both cases work.
if isinstance(rewind_model, torch.nn.DataParallel):
    rewind_model = rewind_model.module
if isinstance(model_new, torch.nn.DataParallel):
    model_new = model_new.module
# Kind of filter
filter_unprune = ['sub_mean.shifter', 'add_mean.shifter']
keep_all_inputs = ['entry']    # layers whose input channels are all kept
keep_all_outputs = ['exit']    # layers whose output channels are all kept
amount = 0.90625               # fraction of channels pruned along each axis
# Index the rewind model once instead of rescanning it for every layer.
rewind_modules = dict(rewind_model.named_modules())
# Pruning!
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
for name, module in model_new.named_modules():
    # Prune or not
    module_name = type(module).__name__
    unprune = 'Conv2d' not in module_name or name in filter_unprune
    logger.debug("=" * 20)
    logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
    if unprune:
        continue
    # Prepare pruning
    rewind_module = rewind_modules.get(name)
    if rewind_module is None:
        raise KeyError(f"rewind model has no module named {name!r}")
    pre_module_weight_shape = module.weight.shape
    pre_module_bias_shape = None if module.bias is None else module.bias.shape
    # Channel masks per weight axis: 1 = in_channel, 0 = out_channel.
    channel_masks = {}
    for dim, keep_all in ((1, keep_all_inputs), (0, keep_all_outputs)):
        if name in keep_all:
            channel_masks[dim] = torch.ones(module.weight.shape[dim], dtype=torch.bool,
                                            device=module.weight.device)
            continue
        # Prune a throwaway copy and read the kept channels from the mask
        # buffer; testing "weight sum != 0" would also drop a live channel
        # whose weights happen to cancel out.
        probe = copy.deepcopy(module)
        prune.ln_structured(probe, 'weight', amount=amount, n=2, dim=dim)
        other_dims = [d for d in range(probe.weight_mask.dim()) if d != dim]
        channel_masks[dim] = probe.weight_mask.sum(dim=other_dims) != 0
    rewind_weight = rewind_module.weight
    in_channel_mask = channel_masks[1].to(rewind_weight.device)
    out_channel_mask = channel_masks[0].to(rewind_weight.device)
    # The surviving channels are filled with the rewind checkpoint's values.
    module.weight = torch.nn.Parameter(
        rewind_weight[:, in_channel_mask][out_channel_mask, :])
    if module.bias is not None:
        module.bias = torch.nn.Parameter(rewind_module.bias[out_channel_mask])
    # Keep the layer's bookkeeping in line with the new weight shape.
    module.out_channels = module.weight.shape[0]
    module.in_channels = module.weight.shape[1] * module.groups
    logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
    logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
    logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> "
                 f"{None if module.bias is None else module.bias.shape}")
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

In [None]:
# Prune MSRResNet

import copy

model_new = copy.deepcopy(model)
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
prune_amount = 0.9  # fraction of output filters removed from each pruned layer
pre_mask_index = torch.ones(3, dtype=torch.bool).to(device)
conv_first_mask_index = None
for name, module in model_new.named_modules():
    # Only modules with 'conv' in their name are candidates; upconv2 is
    # deliberately left intact.
    if 'conv' not in name or name == 'upconv2':
        logger.debug("=" * 20)
        logger.debug(f"{name}: Unpruned")
        continue
    # DEBUG ----------------------------------------------------
    logger.debug("=" * 20)
    logger.debug(f"{name}: Pruned")
    pre_module_weight_shape = module.weight.shape
    pre_module_bias_shape = module.bias.shape
    # DEBUG ----------------------------------------------------
    # Prune in_channel
    if name == 'upconv1':
        # Input channels follow conv_first's kept filters; its outputs stay.
        logger.debug(f"conv_first_mask_index.shape: {conv_first_mask_index.shape}")
        module.weight = torch.nn.Parameter(module.weight[:, conv_first_mask_index])
        # DEBUG ----------------------------------------------------
        logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
        logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {module.bias.shape}")
        # DEBUG ----------------------------------------------------
        continue
    if name not in ['conv_first', 'HRconv']: # Without first
        logger.debug(f"pre_mask_index.shape: {pre_mask_index.shape}")
        module.weight = torch.nn.Parameter(module.weight[:, pre_mask_index])
    if name == 'conv_last':
        # The final layer keeps every output channel. It must not go through
        # ln_structured either: with the 3 output channels configured in
        # load_model (out_nc=3), round(0.9 * 3) == 3 would zero all of them.
        # DEBUG ----------------------------------------------------
        logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
        logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {module.bias.shape}")
        # DEBUG ----------------------------------------------------
        continue
    # Get pruning mask (read from the mask buffer before it is removed)
    prune.ln_structured(module, 'weight', amount=prune_amount, n=2, dim=0)
    mask_index = module.weight_mask.sum(dim=(1, 2, 3)) != 0
    prune.remove(module, 'weight')
    pre_mask_index = mask_index
    if name == 'conv_first':
        conv_first_mask_index = mask_index
    # Prune out_channel
    module.weight = torch.nn.Parameter(module.weight[mask_index, :])
    module.bias = torch.nn.Parameter(module.bias[mask_index])
    # DEBUG ----------------------------------------------------
    logger.debug(f"mask_index.shape: {mask_index.shape}")
    logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
    logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {module.bias.shape}")
    # DEBUG ----------------------------------------------------
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

In [None]:
# Evaluate the pruned network (no ensemble) and score its saved outputs.
# Pass `model` instead of `model_new` to re-run the unpruned baseline.
test(model_new, L_folder, P_folder, logger, save=True)
psnr_after = calculate_psnr(H_folder, P_folder, logger)
logger.info(f'PSNR after pruning {psnr_after}')

In [7]:
# Persist the pruned weights under the 'net' key, the layout load_model
# reads back for most model names.
pruned_model_path = "PRDN_52_rewinded2.pth"  # alt: "MSRResNet/MSRResNetx4_model/PrunedMSRResNetx4.pth"
pruned_checkpoint = {'net': model_new.state_dict()}
torch.save(pruned_checkpoint, pruned_model_path)