In [4]:
import os.path
import logging
import time
import torch
import torch.nn.utils.prune as prune
import numpy as np

from torch.autograd import Variable
from PIL import Image
from collections import OrderedDict
from pathlib import Path
from tqdm import tqdm

from MSRResNet.utils import utils_logger
from MSRResNet.utils import utils_image as util
from MSRResNet.SRResNet import MSRResNet


def get_logger(name):
    # Set logger
    utils_logger.logger_info(name, log_path=f'{name}.log')
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    return logger

def load_model(model_path, device):
    # --------------------------------
    # load model
    # --------------------------------
    model = MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=4)
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    return model

def count_num_of_parameters(model):
    number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    return number_parameters

def test(model, L_folder, out_folder, logger, save):
    # --------------------------------
    # read image
    # --------------------------------
    util.mkdir(out_folder)

    # record PSNR, runtime
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info(L_folder)
    logger.info(out_folder)
    idx = 0

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for img in util.get_image_paths(L_folder):

        # --------------------------------
        # (1) img_L
        # --------------------------------
        idx += 1
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name+ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(device)

        start.record()
        img_E = model(img_L)
        end.record()
        torch.cuda.synchronize()
        test_results['runtime'].append(start.elapsed_time(end))  # milliseconds

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)

        if save:
            new_name = '{:3d}'.format(int(img_name.split('x')[0]))
            path = os.path.join(out_folder, new_name+ext)
            logger.info('Save {:4d} to {:10s}'.format(idx, path))
            util.imsave(img_E, path)
    ave_runtime = sum(test_results['runtime']) / len(test_results['runtime']) / 1000.0
    logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(L_folder, ave_runtime))

def _load_img_array(path, color_mode='RGB',
                    channel_mean=None, modcrop=[0, 0, 0, 0]):
    '''Load an image using PIL and convert it into specified color space,
    and return it as an numpy array.
    https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
    The code is modified from Keras.preprocessing.image.load_img, img_to_array.
    '''
    # Load image
    img = Image.open(path)
    if color_mode == 'RGB':
        cimg = img.convert('RGB')
        x = np.asarray(cimg, dtype='float32')
    elif color_mode == 'YCbCr' or color_mode == 'Y':
        cimg = img.convert('YCbCr')
        x = np.asarray(cimg, dtype='float32')
        if color_mode == 'Y':
            x = x[:, :, 0:1]
    # Normalize To 0-1
    x *= 1.0 / 255.0
    if channel_mean:
        x[:, :, 0] -= channel_mean[0]
        x[:, :, 1] -= channel_mean[1]
        x[:, :, 2] -= channel_mean[2]
    if modcrop[0] * modcrop[1] * modcrop[2] * modcrop[3]:
        x = x[modcrop[0]:-modcrop[1], modcrop[2]:-modcrop[3], :]
    return x

def _rgb2ycbcr(img, maxVal=255):
    # Same as MATLAB's rgb2ycbcr
    # Updated at 03/14/2017
    # Not tested for cb and cr
    O = np.array([[16],
                  [128],
                  [128]])
    T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
                  [-0.148223529411765, -0.290992156862745, 0.439215686274510],
                  [0.439215686274510, -0.367788235294118, -0.071427450980392]])
    if maxVal == 1:
        O = O / 255.0
    t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
    t = np.dot(t, np.transpose(T))
    t[:, 0] += O[0]
    t[:, 1] += O[1]
    t[:, 2] += O[2]
    ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
    return ycbcr

def psnr(y_true, y_pred, shave_border=4):
    '''
        Input must be 0-255, 2D
    '''
    target_data = np.array(y_true, dtype=np.float32)
    ref_data = np.array(y_pred, dtype=np.float32)
    diff = ref_data - target_data
    if shave_border > 0:
        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
    rmse = np.sqrt(np.mean(np.power(diff, 2)))
    return 20 * np.log10(255. / rmse)

def calculate_psnr(testsets_H, testset_O, logger):
    count = 0
    avg_psnr_value = 0.0
    h_list = sorted(list(Path(testsets_H).glob('*.png')))
    o_list = sorted(list(Path(testset_O).glob('*.png')))
    for h_path, o_path in tqdm(zip(h_list, o_list)):
        with torch.no_grad():
            # Load label image
            h_img = _load_img_array(h_path)
            h_img = (h_img * 255).astype(np.uint8)
            # Load output image
            o_img = _load_img_array(o_path)
            o_img = (o_img * 255).astype(np.uint8)
            # Get psnr of bicubic, predicted
            psnr_value = psnr(_rgb2ycbcr(h_img)[:, :, 0],
                              _rgb2ycbcr(o_img)[:, :, 0],
                              4)
            count += 1
            avg_psnr_value += psnr_value
    return (avg_psnr_value / count)

# Initialize variables
data_dir='../dataset'
root_dir='MSRResNet'
save=False
# --------------------------------
# basic settings
# --------------------------------
testsets = f'{data_dir}/DIV2K'
testset_L = f'DIV2K_valid_LR_bicubic'
testset_H = f'DIV2K_valid_HR'
L_folder = os.path.join(testsets, testset_L, 'X4')
H_folder = os.path.join(testsets, testset_H)
E_folder = os.path.join('results')
P_folder = os.path.join('pruned')
torch.cuda.current_device()
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger = get_logger('AIM-track')
model_path = os.path.join(root_dir, 'MSRResNetx4_model', 'MSRResNetx4.pth')

LogHandlers exist!


In [5]:
model = load_model(model_path, device)
logger.info(f'Params number: {count_num_of_parameters(model)}')

20-06-25 12:25:26.763 : Params number: 1517571


In [6]:
test(model, L_folder, E_folder, logger, True)
logger.info(f'PSNR before pruning {calculate_psnr(H_folder, E_folder, logger)}')

20-06-25 12:25:27.642 : ../dataset/DIV2K/DIV2K_valid_LR_bicubic/X4
20-06-25 12:25:27.644 : results
20-06-25 12:25:27.647 : ---1--> 0801x4.png
20-06-25 12:25:29.754 : Save    1 to results/801.png
20-06-25 12:25:29.929 : ---2--> 0802x4.png
20-06-25 12:25:30.077 : Save    2 to results/802.png
20-06-25 12:25:30.245 : ---3--> 0803x4.png
20-06-25 12:25:32.267 : Save    3 to results/803.png
20-06-25 12:25:32.455 : ---4--> 0804x4.png
20-06-25 12:25:33.845 : Save    4 to results/804.png
20-06-25 12:25:33.997 : ---5--> 0805x4.png
20-06-25 12:25:34.149 : Save    5 to results/805.png
20-06-25 12:25:34.332 : ---6--> 0806x4.png
20-06-25 12:25:34.466 : Save    6 to results/806.png
20-06-25 12:25:34.617 : ---7--> 0807x4.png
20-06-25 12:25:35.964 : Save    7 to results/807.png
20-06-25 12:25:36.099 : ---8--> 0808x4.png
20-06-25 12:25:36.231 : Save    8 to results/808.png
20-06-25 12:25:36.361 : ---9--> 0809x4.png
20-06-25 12:25:36.497 : Save    9 to results/809.png
20-06-25 12:25:36.662 : --10--> 0810x

In [7]:
import copy

model_new = copy.deepcopy(model)
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
pre_mask_index = torch.ones(3, dtype=torch.bool).to(device)
for name, module in model_new.named_modules():
    if 'conv' in name:
        if name in ['upconv1', 'upconv2']:
            continue
        prune.ln_structured(module, 'weight', amount=0.5, n=2, dim=0)
        prune.remove(module, 'weight')
        mask_index = module.weight.sum(-1).sum(-1).sum(-1) != 0
        # DEBUG ----------------------------------------------------
        logger.debug("=" * 20)
        logger.debug(f"{name}: Pruned")
        logger.debug(f"pre_mask_index.shape: {pre_mask_index.shape}")
        logger.debug(f"mask_index.shape: {mask_index.shape}")
        pre_module_weight_shape = module.weight.shape
        pre_module_bias_shape = module.bias.shape
        # DEBUG ----------------------------------------------------
        if name not in ['conv_first'] + [f"recon_trunk.{i}.conv1" for i in range(16)] + ['HRconv']:
            module.weight = torch.nn.Parameter(module.weight[:, pre_mask_index])
        if name not in ['conv_first'] + [f"recon_trunk.{i}.conv2" for i in range(16)] + ['conv_last']:
            module.weight = torch.nn.Parameter(module.weight[mask_index, :])
            module.bias = torch.nn.Parameter(module.bias[mask_index])
        # DEBUG ----------------------------------------------------
        logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
        logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {module.bias.shape}")
        # DEBUG ----------------------------------------------------
        pre_mask_index = mask_index
    else:
        logger.debug("=" * 20)
        logger.debug(f"{name}: Unpruned")
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

20-06-25 12:27:27.544 : Params number(Before prune): 1517571
20-06-25 12:27:27.547 : : Unpruned
20-06-25 12:27:27.551 : conv_first: Pruned
20-06-25 12:27:27.552 : pre_mask_index.shape: torch.Size([3])
20-06-25 12:27:27.553 : mask_index.shape: torch.Size([64])
20-06-25 12:27:27.553 : module.weight.shape: torch.Size([64, 3, 3, 3]) --> torch.Size([64, 3, 3, 3])
20-06-25 12:27:27.554 : module.bias.shape: torch.Size([64]) --> torch.Size([64])
20-06-25 12:27:27.556 : recon_trunk: Unpruned
20-06-25 12:27:27.557 : recon_trunk.0: Unpruned
20-06-25 12:27:27.562 : recon_trunk.0.conv1: Pruned
20-06-25 12:27:27.563 : pre_mask_index.shape: torch.Size([64])
20-06-25 12:27:27.564 : mask_index.shape: torch.Size([64])
20-06-25 12:27:27.566 : module.weight.shape: torch.Size([64, 64, 3, 3]) --> torch.Size([32, 64, 3, 3])
20-06-25 12:27:27.567 : module.bias.shape: torch.Size([64]) --> torch.Size([32])
20-06-25 12:27:27.569 : recon_trunk.0.conv2: Pruned
20-06-25 12:27:27.570 : pre_mask_index.shape: torch.Si

In [8]:
test(model_new, L_folder, P_folder, logger, True)
logger.info(f'PSNR after pruning {calculate_psnr(H_folder, P_folder, logger)}')

20-06-25 12:27:38.389 : ../dataset/DIV2K/DIV2K_valid_LR_bicubic/X4
20-06-25 12:27:38.391 : pruned
20-06-25 12:27:38.394 : ---1--> 0801x4.png
20-06-25 12:27:39.249 : Save    1 to pruned/801.png
20-06-25 12:27:39.420 : ---2--> 0802x4.png
20-06-25 12:27:39.529 : Save    2 to pruned/802.png
20-06-25 12:27:39.684 : ---3--> 0803x4.png
20-06-25 12:27:40.631 : Save    3 to pruned/803.png
20-06-25 12:27:40.814 : ---4--> 0804x4.png
20-06-25 12:27:41.585 : Save    4 to pruned/804.png
20-06-25 12:27:41.732 : ---5--> 0805x4.png
20-06-25 12:27:41.853 : Save    5 to pruned/805.png
20-06-25 12:27:42.031 : ---6--> 0806x4.png
20-06-25 12:27:42.138 : Save    6 to pruned/806.png
20-06-25 12:27:42.285 : ---7--> 0807x4.png
20-06-25 12:27:42.921 : Save    7 to pruned/807.png
20-06-25 12:27:43.046 : ---8--> 0808x4.png
20-06-25 12:27:43.152 : Save    8 to pruned/808.png
20-06-25 12:27:43.291 : ---9--> 0809x4.png
20-06-25 12:27:43.398 : Save    9 to pruned/809.png
20-06-25 12:27:43.549 : --10--> 0810x4.png
20-0