In [7]:
import os.path
import logging
import time
import torch
import torch.nn.utils.prune as prune
import torchvision.transforms.functional as TVF
import numpy as np
import copy

from collections import OrderedDict
from torch.autograd import Variable
from PIL import Image
from collections import OrderedDict
from pathlib import Path
from tqdm import tqdm

from MSRResNet.utils import utils_logger
from MSRResNet.utils import utils_image as util
from MSRResNet.SRResNet import MSRResNet, PrunedMSRResNet
from CARN import CARN
from rdn import RDN
from RRDBNet_arch import RRDBNet
from RCAN import RCAN


def get_logger(name):
    """Return the logger called ``name`` with its level lowered to DEBUG.

    ``utils_logger.logger_info`` is called first with ``log_path='<name>.log'``;
    presumably it attaches the console/file handlers (project helper — verify).
    """
    log_file = f'{name}.log'
    utils_logger.logger_info(name, log_path=log_file)
    named_logger = logging.getLogger(name)
    named_logger.setLevel(logging.DEBUG)
    return named_logger

def count_num_of_parameters(model):
    """Return the total number of scalar elements over ``model.parameters()``.

    Frozen parameters (``requires_grad=False``) are counted too.
    """
    return sum(param.numel() for param in model.parameters())

def _load_img_array(path, color_mode='RGB',
                    channel_mean=None, modcrop=(0, 0, 0, 0)):
    '''Load an image with PIL and return it as a float32 array scaled to [0, 1].

    Adapted from Keras' ``load_img`` / ``img_to_array``:
    https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py

    Args:
        path: image file path (anything ``PIL.Image.open`` accepts).
        color_mode: 'RGB' -> (H, W, 3) RGB; 'YCbCr' -> (H, W, 3) YCbCr;
            'Y' -> (H, W, 1) luma channel of YCbCr.
        channel_mean: optional per-channel values subtracted after scaling;
            only the first C entries are used for a C-channel result.
        modcrop: (top, bottom, left, right) pixels removed from each border.
            A 0 entry leaves that border untouched.

    Returns:
        float32 ndarray of shape (H, W, C).

    Raises:
        ValueError: if ``color_mode`` is not 'RGB', 'YCbCr' or 'Y'.
    '''
    if color_mode == 'RGB':
        pil_mode = 'RGB'
    elif color_mode in ('YCbCr', 'Y'):
        pil_mode = 'YCbCr'
    else:
        # Previously this surfaced as an UnboundLocalError on `x`.
        raise ValueError(f"Unsupported color_mode: {color_mode!r}")
    # Context manager closes the underlying file handle; the dtype conversion
    # copies the pixels, so the array stays valid after the file is closed.
    with Image.open(path) as img:
        x = np.asarray(img.convert(pil_mode), dtype='float32')
    if color_mode == 'Y':
        x = x[:, :, 0:1]
    # Normalize to 0-1
    x *= 1.0 / 255.0
    if channel_mean is not None and len(channel_mean) > 0:
        # Slice to the available channels so 'Y' (1 channel) no longer raises IndexError.
        x -= np.asarray(channel_mean[:x.shape[2]], dtype=np.float32)
    top, bottom, left, right = modcrop
    height, width = x.shape[:2]
    # Explicit end indices: a negative-slice form `-0` would yield an empty array,
    # which is why the old code only cropped when all four values were non-zero.
    x = x[top:height - bottom, left:width - right, :]
    return x

def _rgb2ycbcr(img, maxVal=255):
    # Same as MATLAB's rgb2ycbcr
    # Updated at 03/14/2017
    # Not tested for cb and cr
    O = np.array([[16],
                  [128],
                  [128]])
    T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
                  [-0.148223529411765, -0.290992156862745, 0.439215686274510],
                  [0.439215686274510, -0.367788235294118, -0.071427450980392]])
    if maxVal == 1:
        O = O / 255.0
    t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
    t = np.dot(t, np.transpose(T))
    t[:, 0] += O[0]
    t[:, 1] += O[1]
    t[:, 2] += O[2]
    ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
    return ycbcr

def psnr(y_true, y_pred, shave_border=4, max_val=255.):
    '''Peak signal-to-noise ratio (dB) between two images.

    Args:
        y_true, y_pred: arrays of identical shape, (H, W) or (H, W, C), on a
            0..max_val scale.
        shave_border: pixels excluded from each spatial border before the
            error is computed; 0 disables shaving.
        max_val: peak signal value (255 for 8-bit images).

    Returns:
        PSNR in dB; ``inf`` when the (shaved) images are identical.

    Raises:
        ValueError: if shaving leaves no pixels to compare.
    '''
    target_data = np.array(y_true, dtype=np.float32)
    ref_data = np.array(y_pred, dtype=np.float32)
    diff = ref_data - target_data
    if shave_border > 0:
        diff = diff[shave_border:-shave_border, shave_border:-shave_border]
    if diff.size == 0:
        # np.mean of an empty array would silently return nan.
        raise ValueError(f"shave_border={shave_border} leaves no pixels for an image of shape {target_data.shape}")
    rmse = np.sqrt(np.mean(np.power(diff, 2)))
    if rmse == 0:
        # Identical images: PSNR is infinite; avoid the divide-by-zero warning.
        return float('inf')
    return 20 * np.log10(max_val / rmse)

def calculate_psnr(testsets_H, testset_O, logger):
    '''Average RGB PSNR between ground-truth PNGs and output PNGs.

    Images are paired by sorted filename, so both folders must hold the same
    number of ``*.png`` files and their sort orders must line up.

    Args:
        testsets_H: folder with ground-truth (HR) PNG images.
        testset_O: folder with model-output PNG images.
        logger: receives the final average.

    Returns:
        Mean PSNR in dB over all pairs, with a 4-pixel border shaved.

    Raises:
        ValueError: if the folders have no PNGs or different numbers of them.
    '''
    h_list = sorted(Path(testsets_H).glob('*.png'))
    o_list = sorted(Path(testset_O).glob('*.png'))
    if not h_list or len(h_list) != len(o_list):
        # zip() would silently truncate and mis-pair images; an empty folder
        # would previously end in ZeroDivisionError.
        raise ValueError(f"Expected matching non-empty PNG sets, got {len(h_list)} in "
                         f"{testsets_H} and {len(o_list)} in {testset_O}")
    total_psnr = 0.0
    for h_path, o_path in tqdm(zip(h_list, o_list), total=len(h_list)):
        # Round before casting: float32 (k/255)*255 can land just below k and
        # astype(uint8) would truncate it to k-1.
        h_img = np.rint(_load_img_array(h_path) * 255).astype(np.uint8)
        o_img = np.rint(_load_img_array(o_path) * 255).astype(np.uint8)
        # RGB PSNR; a Y-only variant would pass _rgb2ycbcr(img)[:, :, 0] instead.
        total_psnr += psnr(h_img, o_img, 4)
    avg_psnr_value = total_psnr / len(h_list)
    logger.info(f'Average PSNR over {len(h_list)} images: {avg_psnr_value}')
    return avg_psnr_value

# Initialize variables
data_dir = '../dataset'
root_dir = 'MSRResNet'
save = False
# --------------------------------
# basic settings
# --------------------------------
testsets = f'{data_dir}/DIV2K'
testset_L = 'DIV2K_valid_LR_bicubic'
testset_H = 'DIV2K_valid_HR'
L_folder = os.path.join(testsets, testset_L, 'X4')
H_folder = os.path.join(testsets, testset_H)
E_folder = 'results'
P_folder = 'pruned'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # These calls raise on CPU-only machines, so only touch CUDA when it exists.
    torch.cuda.current_device()
    torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
logger = get_logger('AIM-track')

model_name = 'RRDBNet_4'
model_path = os.path.join('model/RRDB_v16_e17000.pth')
# NOTE: the RCAN pruning cell further down reads `rewind_path`; re-enable this
# line (pointing at the matching checkpoint) before running that cell.
#rewind_path = os.path.join('RCAN_v2_e0.pth')

LogHandlers exist!


In [8]:
def load_model(model_path, device, name='MSRResNet'):
    # --------------------------------
    # load model
    # --------------------------------
    if name == 'MSRResNet':
        model = MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=4)
        model.load_state_dict(torch.load(model_path), strict=True)
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'PMSRResNet':
        model = torch.nn.DataParallel(PrunedMSRResNet())
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'], strict=True)
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'CARN':
        model = torch.nn.DataParallel(CARN(multi_scale=4, group=1))
        # Load pretrained model
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        # Return net
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'PCARN':
        model = CARN(multi_scale=4, group=1, channel_cnt=12)
        # Load pretrained model
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        # Return net
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'PCARN32':
        #model = torch.nn.DataParallel(CARN(multi_scale=4, group=1, channel_cnt=32))
        model = CARN(multi_scale=4, group=1, channel_cnt=32)
        # Load pretrained model
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        # Return net
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        #modle = model.module
        return model
    elif name == 'PCARN6':
        model = CARN(multi_scale=4, group=1, channel_cnt=6)
        # Load pretrained model
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        # Return net
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'SRDenseNet':
        model = torch.load(model_path)['model']
        # Return net
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'RDN':
        model = RDN(scale_factor=4, num_channels=3, num_features=64, growth_rate=64, num_blocks=16, num_layers=8)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'PRDN_58':
        model = torch.nn.DataParallel(RDN(scale_factor=4, num_channels=3, num_features=58, growth_rate=58, num_blocks=16, num_layers=8))
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        model = model.module
        return model
    elif name == 'PRDN_52':
        model = torch.nn.DataParallel(RDN(scale_factor=4, num_channels=3, num_features=52, growth_rate=52, num_blocks=16, num_layers=8))
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        model = model.module
        return model
    elif name == 'PRDN_46':
        model = torch.nn.DataParallel(RDN(scale_factor=4, num_channels=3, num_features=46, growth_rate=46, num_blocks=16, num_layers=8))
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        model = model.module
        return model
    elif name == 'PRDN_40':
        model = torch.nn.DataParallel(RDN(scale_factor=4, num_channels=3, num_features=40, growth_rate=40, num_blocks=16, num_layers=8))
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        model = model.module
        return model
    elif name == 'PRDN_34':
        model = RDN(scale_factor=4, num_channels=3, num_features=34, growth_rate=34, num_blocks=16, num_layers=8)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'PRDN_32':
        model = torch.nn.DataParallel(RDN(scale_factor=4, num_channels=3, num_features=32, growth_rate=32, num_blocks=16, num_layers=8))
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        model = model.module
        return model
    elif name == 'PRDN_28':
        model = RDN(scale_factor=4, num_channels=3, num_features=28, growth_rate=28, num_blocks=16, num_layers=8)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'PRDN_22':
        model = RDN(scale_factor=4, num_channels=3, num_features=22, growth_rate=22, num_blocks=16, num_layers=8)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'PRDN_12':
        model = torch.nn.DataParallel(RDN(scale_factor=4, num_channels=3, num_features=12, growth_rate=12, num_blocks=16, num_layers=8))
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        model = model.module
        return model
    elif name == 'PRDN_6':
        model = torch.nn.DataParallel(RDN(scale_factor=4, num_channels=3, num_features=6, growth_rate=6, num_blocks=16, num_layers=8))
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        model = model.module
        return model
    elif name == 'RRDBNet':
        model = RRDBNet(3, 3, 64, 23, gc=32)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'RRDBNet_32':
        model = RRDBNet(3, 3, 32, 23, gc=16)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'RRDBNet_16':
        model = RRDBNet(3, 3, 16, 23, gc=8)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'RRDBNet_12':
        model = RRDBNet(3, 3, 12, 23, gc=6)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'RRDBNet_8':
        model = RRDBNet(3, 3, 8, 23, gc=4)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'RRDBNet_4':
        model = RRDBNet(3, 3, 4, 23, gc=2)
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model
    elif name == 'RCAN':
        model = RCAN()
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['net'])
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        model = model.to(device)
        return model

# Load the configured checkpoint and report its size
model = load_model(model_path, device, name=model_name)
print('Params number: {}'.format(count_num_of_parameters(model)))

Params number: 66227


In [9]:
def _self_ensemble_x8(model, img_L):
    """Average ``model`` outputs over the 8 rotation/flip variants of ``img_L``.

    For each of the 4 rotations, without and then with a horizontal flip, the
    output is mapped back (inverse rotation, then un-flip) and accumulated in
    the same order as the original inline code.
    """
    img_E = None
    for flipped in (False, True):
        src = img_L.flip(-1) if flipped else img_L
        for k in range(4):
            out = model(src.rot90(k, (-1, -2))).rot90((4 - k) % 4, (-1, -2))
            if flipped:
                out = out.flip(-1)
            img_E = out if img_E is None else img_E + out
    return img_E / 8


def test(model, L_folder, out_folder, logger, save, ensemble=False):
    """Super-resolve every image in ``L_folder`` and print the mean runtime.

    Args:
        model: SR network, called on tensors from ``util.uint2tensor4``.
        L_folder: folder with low-resolution inputs.
        out_folder: destination folder (created if missing).
        logger: receives the folder names and per-image progress lines.
        save: when True, write each result as ``<n><ext>`` where n is the
            integer before the first 'x' in the input stem.
        ensemble: when True, use the 8-way rotation/flip self-ensemble.

    Inputs are moved to the module-level ``device``. Runtime is measured with
    CUDA events on GPU and ``time.perf_counter`` otherwise.
    """
    util.mkdir(out_folder)

    # per-image runtime, milliseconds
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info(L_folder)
    logger.info(out_folder)

    # CUDA events are only usable when a GPU is present.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

    for idx, img in enumerate(util.get_image_paths(L_folder), start=1):
        # --------------------------------
        # (1) img_L
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name+ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(device)

        use_cuda_events = img_L.is_cuda
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            if use_cuda_events:
                start.record()
            else:
                t0 = time.perf_counter()
            img_E = _self_ensemble_x8(model, img_L) if ensemble else model(img_L)
            if use_cuda_events:
                end.record()
                torch.cuda.synchronize()
                test_results['runtime'].append(start.elapsed_time(end))  # milliseconds
            else:
                test_results['runtime'].append((time.perf_counter() - t0) * 1000.0)

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)

        if save:
            new_name = '{:3d}'.format(int(img_name.split('x')[0]))
            path = os.path.join(out_folder, new_name+ext)
            print('Save {:4d} to {:10s}'.format(idx, path))
            util.imsave(img_E, path)
    if not test_results['runtime']:
        # Avoid ZeroDivisionError when the folder has no images.
        logger.warning(f'No images found in {L_folder}')
        return
    ave_runtime = sum(test_results['runtime']) / len(test_results['runtime']) / 1000.0
    print('------> Average runtime of ({}) is : {:.6f} seconds'.format(L_folder, ave_runtime))

# Baseline: x8 self-ensemble inference with outputs saved, then PSNR against HR
test(model, L_folder, E_folder, logger, save=True, ensemble=True)
print('PSNR before pruning {}'.format(calculate_psnr(H_folder, E_folder, logger)))

Save    1 to results/801.png
Save    2 to results/802.png
Save    3 to results/803.png
Save    4 to results/804.png
Save    5 to results/805.png
Save    6 to results/806.png
Save    7 to results/807.png
Save    8 to results/808.png
Save    9 to results/809.png
Save   10 to results/810.png
Save   11 to results/811.png
Save   12 to results/812.png
Save   13 to results/813.png
Save   14 to results/814.png
Save   15 to results/815.png
Save   16 to results/816.png
Save   17 to results/817.png
Save   18 to results/818.png
Save   19 to results/819.png
Save   20 to results/820.png
Save   21 to results/821.png
Save   22 to results/822.png
Save   23 to results/823.png
Save   24 to results/824.png
Save   25 to results/825.png
Save   26 to results/826.png
Save   27 to results/827.png
Save   28 to results/828.png
Save   29 to results/829.png
Save   30 to results/830.png
Save   31 to results/831.png
Save   32 to results/832.png
Save   33 to results/833.png
Save   34 to results/834.png
Save   35 to r

0it [00:00, ?it/s]

Save  100 to results/900.png
------> Average runtime of (../dataset/DIV2K/DIV2K_valid_LR_bicubic/X4) is : 0.928608 seconds


100it [00:49,  2.02it/s]

PSNR before pruning 28.18131381073257





In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.manifold import TSNE
from tqdm import tqdm

# Wide, short figures for the side-by-side linear/log panels
plt.rcParams.update({"figure.figsize": (14, 4)})

def visualize_both(name, norm):
    """Show two panels for ``norm``: raw values on a linear x-axis (left) and
    absolute values on a log x-axis (right)."""
    fig, (lin_ax, log_ax) = plt.subplots(1, 2)
    lin_ax.set_xscale('linear')
    visualize_norm_stat(f"{name}_lin", norm, lin_ax)
    log_ax.set_xscale('log')
    visualize_norm_stat(f"{name}_log", norm.abs(), log_ax)
    plt.show()
    #plt.savefig(f"images/{name}.png", dpi=300)

def visualize_norm_stat(name, norm, ax):
    """Draw a rug of ``norm``'s values and their Gaussian KDE on ``ax``.

    Args:
        name: legend label for the KDE curve.
        norm: 1-D array-like of values (CPU tensor or numpy array).
        ax: matplotlib-style Axes (``plot`` and ``legend`` are used).

    The KDE is evaluated on 512 evenly spaced points covering the data range
    padded by 10% of its span on each side. This replaces the old fixed 0.1 step,
    which ignored the data scale, and the ``min*1.1`` start, which cut off the low
    end for positive data. With fewer than two distinct values the KDE is
    undefined (singular covariance), so only the rug is drawn.
    """
    values = np.asarray(norm, dtype=np.float64).ravel()
    # Draw each point as a tick on y=0
    ax.plot(values, np.zeros_like(values), '|')
    if values.size < 2 or values.min() == values.max():
        return
    estimator = stats.gaussian_kde(values, bw_method='silverman')
    lo, hi = values.min(), values.max()
    pad = 0.1 * (hi - lo)
    X = np.linspace(lo - pad, hi + pad, 512)
    # Draw kernel density estimate
    ax.plot(X, estimator(X), label=f'{name}')
    ax.legend(loc='best')
# Conv layers to skip (names follow the DataParallel 'module.' prefix;
# presumably the fixed mean-shift convs — confirm for the loaded model).
filter_unprune = ['module.sub_mean.shifter', 'module.add_mean.shifter']
for name, module in model.named_modules():
    is_conv = 'Conv2d' in type(module).__name__
    if is_conv and name not in filter_unprune:
        # Sum each output filter over (in, kh, kw) -> one value per output channel
        per_out_channel = module.weight.sum(-1).sum(-1).sum(-1)
        visualize_both(name, per_out_channel.detach().cpu())

In [None]:
import copy

def prune_with_synflow(model, origin_channel, prune_channel_list, filter_unprune, filter_in, filter_out):
    """Iteratively prune Conv2d channels of a copy of ``model`` with SynFlow-style scores.

    Each round:
      1. Builds a scoring copy where every Conv2d weight becomes |w| * 1e-2
         and its bias zeros, feeds an all-ones 1x3x32x32 input (on the model's
         device) and backpropagates the output sum.
      2. Scores each weight as w_abs * grad. For every Conv2d not in
         ``filter_unprune``, it removes the ``amount`` fraction of input channels
         (unless in ``filter_in``) and output channels (unless in ``filter_out``)
         with the smallest L2 score norm.

    Args:
        model: network to prune; deep-copied, the original is untouched.
        origin_channel: channel width before the first round, used to turn each
            ``prune_channel`` count into a fraction of the current width.
        prune_channel_list: channels removed per round.
        filter_unprune / filter_in / filter_out: ``named_modules`` names to skip
            entirely / keep all inputs / keep all outputs.

    Returns:
        The pruned copy (left in train mode).

    NOTE(review): input and output channels are chosen per layer independently,
    so a layer's kept inputs are not matched to the previous layer's kept outputs.
    """
    # Copy model
    model_new = copy.deepcopy(model)
    model_new.train()
    for prune_channel in prune_channel_list:
        # Fraction of the current width removed this round
        amount = prune_channel / origin_channel
        logger.debug("=" * 20)
        logger.debug("=" * 20)
        logger.debug(f"{origin_channel} to {origin_channel - prune_channel} with amount: {amount}")
        logger.debug("=" * 20)
        logger.debug("=" * 20)
        origin_channel -= prune_channel
        # Scoring copy: |w| * 1e-2 weights, zero biases on every Conv2d
        model_abs = copy.deepcopy(model_new)
        for module in model_abs.modules():
            if 'Conv2d' not in type(module).__name__:
                continue
            module.weight = torch.nn.Parameter(module.weight.detach().abs() * 1e-2)
            if module.bias is not None:
                module.bias = torch.nn.Parameter(torch.zeros_like(module.bias))
        # SynFlow: all-ones input on the model's own device
        model_abs.zero_grad()
        input_device = next(model_abs.parameters()).device
        img = torch.ones((1, 3, 32, 32), device=input_device)
        model_abs(img).sum().backward()
        # O(1) lookup instead of rescanning named_modules() per layer
        abs_modules = dict(model_abs.named_modules())
        # do pruning
        for name, module in model_new.named_modules():
            # Prune or not
            module_name = type(module).__name__
            unprune = 'Conv2d' not in module_name or name in filter_unprune
            logger.debug("=" * 20)
            logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
            if unprune:
                continue
            module_abs = abs_modules[name]
            # Prepare pruning
            pre_module_weight_shape = module.weight.shape
            pre_module_bias_shape = None if module.bias is None else module.bias.shape
            # SynFlow score
            module_abs.weight = torch.nn.Parameter(torch.mul(module_abs.weight, module_abs.weight.grad))
            # Prune in_channel; kept channels are read from the pruning mask itself
            # rather than inferred from score sums (which may be zero by accident).
            if name not in filter_in:
                prune.ln_structured(module_abs, 'weight', amount=amount, n=2, dim=1)
                in_channel_mask = module_abs.weight_mask.sum(dim=(0, 2, 3)) != 0
                prune.remove(module_abs, 'weight')
            else:
                in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool)
            # Prune out_channel
            if name not in filter_out:
                prune.ln_structured(module_abs, 'weight', amount=amount, n=2, dim=0)
                out_channel_mask = module_abs.weight_mask.sum(dim=(1, 2, 3)) != 0
                prune.remove(module_abs, 'weight')
            else:
                out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool)
            # Masks must live on the same device as the weights they index
            in_channel_mask = in_channel_mask.to(module.weight.device)
            out_channel_mask = out_channel_mask.to(module.weight.device)
            module.weight = torch.nn.Parameter(module.weight[out_channel_mask][:, in_channel_mask])
            if module.bias is not None:
                module.bias = torch.nn.Parameter(module.bias[out_channel_mask])
            if hasattr(module, 'in_channels'):
                # Keep the Conv2d metadata consistent with the sliced weight
                module.in_channels = int(in_channel_mask.sum())
                module.out_channels = int(out_channel_mask.sum())
            logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
            logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
            logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {None if module.bias is None else module.bias.shape}")
    return model_new

In [None]:
import copy
import torch

def prune_with_srflow(model, origin_channel, prune_channel_list, filter_unprune, filter_in, filter_out):
    """Iteratively prune Conv2d channels of a copy of ``model`` with |w|*|grad| scores.

    Each round backpropagates the MSE between the model's output for an all-ones
    1x3x32x32 input and an all-ones 1x3x128x128 target, which assumes a x4 model.
    It then scores each Conv2d weight as |w| * |dL/dw|. For every Conv2d not in
    ``filter_unprune``, it removes the ``amount`` fraction of input channels
    (unless in ``filter_in``) and output channels (unless in ``filter_out``)
    with the smallest L2 score norm.

    Args:
        model: network to prune; deep-copied, the original is untouched.
        origin_channel: channel width before the first round.
        prune_channel_list: channels removed per round.
        filter_unprune / filter_in / filter_out: ``named_modules`` names to skip
            entirely / keep all inputs / keep all outputs.

    Returns:
        The pruned copy (left in train mode).
    """
    # Copy model
    model_new = copy.deepcopy(model)
    model_new.train()
    for prune_channel in prune_channel_list:
        # Fraction of the current width removed this round
        amount = prune_channel / origin_channel
        logger.debug("=" * 20)
        logger.debug("=" * 20)
        logger.debug(f"{origin_channel} to {origin_channel - prune_channel} with amount: {amount}")
        logger.debug("=" * 20)
        logger.debug("=" * 20)
        origin_channel -= prune_channel
        # SRFlow on a scoring copy. requires_grad_ is needed when `model` was
        # frozen (e.g. by load_model); otherwise backward() raises.
        model_flow = copy.deepcopy(model_new)
        model_flow.requires_grad_(True)
        model_flow.zero_grad()
        flow_device = next(model_flow.parameters()).device
        img_lr = torch.ones((1, 3, 32, 32), device=flow_device)
        img_hr = torch.ones((1, 3, 32 * 4, 32 * 4), device=flow_device)
        img_sr = model_flow(img_lr)
        loss = torch.nn.MSELoss()(img_hr, img_sr)
        loss.backward()
        flow_modules = dict(model_flow.named_modules())
        # do pruning
        for name, module in model_new.named_modules():
            # Prune or not
            module_name = type(module).__name__
            unprune = 'Conv2d' not in module_name or name in filter_unprune
            logger.debug("=" * 20)
            logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
            if unprune:
                continue
            module_flow = flow_modules[name]
            # Prepare pruning
            pre_module_weight_shape = module.weight.shape
            pre_module_bias_shape = None if module.bias is None else module.bias.shape
            # SRFlow score
            module_flow.weight = torch.nn.Parameter(torch.mul(module_flow.weight.abs(), module_flow.weight.grad.abs()))
            # Prune in_channel; kept channels come from the pruning mask itself
            if name not in filter_in:
                prune.ln_structured(module_flow, 'weight', amount=amount, n=2, dim=1)
                in_channel_mask = module_flow.weight_mask.sum(dim=(0, 2, 3)) != 0
                prune.remove(module_flow, 'weight')
            else:
                in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool)
            # Prune out_channel
            if name not in filter_out:
                prune.ln_structured(module_flow, 'weight', amount=amount, n=2, dim=0)
                out_channel_mask = module_flow.weight_mask.sum(dim=(1, 2, 3)) != 0
                prune.remove(module_flow, 'weight')
            else:
                out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool)
            # Masks must live on the same device as the weights they index
            in_channel_mask = in_channel_mask.to(module.weight.device)
            out_channel_mask = out_channel_mask.to(module.weight.device)
            module.weight = torch.nn.Parameter(module.weight[out_channel_mask][:, in_channel_mask])
            if module.bias is not None:
                module.bias = torch.nn.Parameter(module.bias[out_channel_mask])
            if hasattr(module, 'in_channels'):
                module.in_channels = int(in_channel_mask.sum())
                module.out_channels = int(out_channel_mask.sum())
            logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
            logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
            logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {None if module.bias is None else module.bias.shape}")
    return model_new

In [None]:
import copy
import torch


def prune_with_srflow2(model, origin_channel, prune_channel_list, filter_unprune, filter_in, filter_out, device, epoch=50):
    """SRFlow pruning with short Adam fitting before each round and after the last.

    Each round first runs ``epoch`` Adam steps, fitting the model's output for an
    all-ones 1x3x32x32 input to an all-ones 1x3x128x128 target. It then scores
    Conv2d weights as |w| * |dL/dw| and removes the ``amount`` fraction of
    input/output channels with the smallest L2 score norm. ``filter_*`` work as
    in ``prune_with_srflow``. A final ``epoch`` Adam steps run after the last round.

    Args:
        device: device for the dummy input/target; ``model`` is expected there too.
        epoch: Adam steps per fitting phase.

    Returns:
        The pruned, fitted copy (train mode); ``model`` itself is untouched.
    """
    # Copy model; re-enable grads in case `model` was frozen (e.g. by load_model),
    # otherwise the fitting loop's backward() would raise.
    model_new = copy.deepcopy(model)
    model_new.requires_grad_(True)
    model_new.train()
    # Image
    img_lr = torch.ones((1, 3, 32, 32)).to(device)
    img_hr = torch.ones((1, 3, 32 * 4, 32 * 4)).to(device)
    criterion = torch.nn.MSELoss()
    for i, prune_channel in enumerate(prune_channel_list):
        # Fraction of the current width removed this round
        amount = prune_channel / origin_channel
        logger.debug("=" * 20)
        logger.debug("=" * 20)
        logger.debug(f"{origin_channel} to {origin_channel - prune_channel} with amount: {amount}")
        logger.debug("=" * 20)
        logger.debug("=" * 20)
        origin_channel -= prune_channel
        # Train model a little bit (new optimizer: parameters were replaced last round)
        optimizer = torch.optim.Adam(model_new.parameters())
        for j in range(epoch):
            optimizer.zero_grad()
            img_sr = model_new(img_lr)
            loss = criterion(img_hr, img_sr)
            loss.backward()
            optimizer.step()
            print(f"[{i + 1}/{len(prune_channel_list)}][{j + 1}/{epoch}] loss: {loss}")
        # SRFlow
        model_flow = copy.deepcopy(model_new)
        model_flow.zero_grad()
        img_sr = model_flow(img_lr)
        loss = criterion(img_hr, img_sr)
        loss.backward()
        flow_modules = dict(model_flow.named_modules())
        # do pruning
        for name, module in model_new.named_modules():
            # Prune or not
            module_name = type(module).__name__
            unprune = 'Conv2d' not in module_name or name in filter_unprune
            logger.debug("=" * 20)
            logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
            if unprune:
                continue
            module_flow = flow_modules[name]
            # Prepare pruning
            pre_module_weight_shape = module.weight.shape
            pre_module_bias_shape = None if module.bias is None else module.bias.shape
            # SRFlow score
            module_flow.weight = torch.nn.Parameter(torch.mul(module_flow.weight.abs(), module_flow.weight.grad.abs()))
            # Prune in_channel; kept channels come from the pruning mask itself
            if name not in filter_in:
                prune.ln_structured(module_flow, 'weight', amount=amount, n=2, dim=1)
                in_channel_mask = module_flow.weight_mask.sum(dim=(0, 2, 3)) != 0
                prune.remove(module_flow, 'weight')
            else:
                in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool)
            # Prune out_channel
            if name not in filter_out:
                prune.ln_structured(module_flow, 'weight', amount=amount, n=2, dim=0)
                out_channel_mask = module_flow.weight_mask.sum(dim=(1, 2, 3)) != 0
                prune.remove(module_flow, 'weight')
            else:
                out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool)
            # Masks must live on the same device as the weights they index
            in_channel_mask = in_channel_mask.to(module.weight.device)
            out_channel_mask = out_channel_mask.to(module.weight.device)
            module.weight = torch.nn.Parameter(module.weight[out_channel_mask][:, in_channel_mask])
            if module.bias is not None:
                module.bias = torch.nn.Parameter(module.bias[out_channel_mask])
            if hasattr(module, 'in_channels'):
                module.in_channels = int(in_channel_mask.sum())
                module.out_channels = int(out_channel_mask.sum())
            logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
            logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
            logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {None if module.bias is None else module.bias.shape}")
    # Last fine-tuning
    optimizer = torch.optim.Adam(model_new.parameters())
    for j in range(epoch):
        optimizer.zero_grad()
        img_sr = model_new(img_lr)
        loss = criterion(img_hr, img_sr)
        loss.backward()
        optimizer.step()
        print(f"[{j + 1}/{epoch}] loss: {loss}")
    return model_new

In [None]:
import copy
import torch

def prune_with_random(model, origin_channel, prune_channel_list, filter_unprune, filter_in, filter_out):
    """Random structured channel-pruning baseline on a copy of ``model``.

    Each round, for every Conv2d not in ``filter_unprune``, it removes a random
    ``amount`` fraction of input channels (unless in ``filter_in``) and of output
    channels (unless in ``filter_out``), then slices them out of the weights.

    Args:
        model: network to prune; deep-copied, the original is untouched.
        origin_channel: channel width before the first round.
        prune_channel_list: channels removed per round.
        filter_unprune / filter_in / filter_out: ``named_modules`` names to skip
            entirely / keep all inputs / keep all outputs.

    Returns:
        The pruned copy (left in train mode).
    """
    # Copy model
    model_new = copy.deepcopy(model)
    model_new.train()
    for prune_channel in prune_channel_list:
        # Fraction of the current width removed this round
        amount = prune_channel / origin_channel
        logger.debug("=" * 20)
        logger.debug("=" * 20)
        logger.debug(f"{origin_channel} to {origin_channel - prune_channel} with amount: {amount}")
        logger.debug("=" * 20)
        logger.debug("=" * 20)
        origin_channel -= prune_channel
        # do pruning
        for name, module in model_new.named_modules():
            # Prune or not
            module_name = type(module).__name__
            unprune = 'Conv2d' not in module_name or name in filter_unprune
            logger.debug("=" * 20)
            logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
            if unprune:
                continue
            # Prepare pruning
            pre_module_weight_shape = module.weight.shape
            pre_module_bias_shape = None if module.bias is None else module.bias.shape
            # Prune in_channel; kept channels come from the pruning mask, not
            # from weight sums that could be zero for surviving channels.
            if name not in filter_in:
                prune.random_structured(module, 'weight', amount=amount, dim=1)
                in_channel_mask = module.weight_mask.sum(dim=(0, 2, 3)) != 0
                prune.remove(module, 'weight')
            else:
                in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool)
            # Prune out_channel
            if name not in filter_out:
                prune.random_structured(module, 'weight', amount=amount, dim=0)
                out_channel_mask = module.weight_mask.sum(dim=(1, 2, 3)) != 0
                prune.remove(module, 'weight')
            else:
                out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool)
            # Masks must live on the same device as the weights they index
            in_channel_mask = in_channel_mask.to(module.weight.device)
            out_channel_mask = out_channel_mask.to(module.weight.device)
            module.weight = torch.nn.Parameter(module.weight[out_channel_mask][:, in_channel_mask])
            if module.bias is not None:
                module.bias = torch.nn.Parameter(module.bias[out_channel_mask])
            if hasattr(module, 'in_channels'):
                module.in_channels = int(in_channel_mask.sum())
                module.out_channels = int(out_channel_mask.sum())
            logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
            logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
            logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {None if module.bias is None else module.bias.shape}")
    return model_new

In [None]:
# RRDBNet (64 channels): remove 2 channels per round for 16 rounds -> 32 channels,
# briefly fitting an all-ones dummy pair around each round (prune_with_srflow2).
origin_channel = 64
prune_channel_list = [2] * 16
# Approximate parameter ratio: conv weights scale with (kept width / width)^2
compression_rate = ((origin_channel - sum(prune_channel_list)) / origin_channel) ** 2
# Module names left unpruned entirely / on their input side / on their output side
filter_unprune = []
filter_in = ['conv_first']
filter_out = ['conv_last']

model = RRDBNet(3, 3, 64, 23, gc=32).to(device)
print('Params number(Before Pruning): {}'.format(count_num_of_parameters(model)))
print('compression_rate: {}'.format(compression_rate))
print('Params number(Predict): {}'.format(count_num_of_parameters(model) * compression_rate))
# prune_with_synflow / prune_with_srflow / prune_with_random take the same
# arguments without `device` and can be swapped in here.
pruned_model = prune_with_srflow2(model, origin_channel, prune_channel_list,
                                  filter_unprune, filter_in, filter_out, device)
print('Params number(After Pruning): {}'.format(count_num_of_parameters(pruned_model)))

In [None]:
# RDN channel pruning with prune_with_synflow.
# Each entry of prune_channel_list is a number of channels to remove from the
# original width; the predicted parameter count scales with the square of the
# kept-width ratio.
origin_channel = 64
prune_channel_list = [2] * (16 + 8)
compression_rate = ((origin_channel - sum(prune_channel_list)) / origin_channel) ** 2
filter_unprune = []
filter_in = ['sfe1']      # keep all input channels of these layers
filter_out = ['output']   # keep all output channels of these layers

# Place the model on `device`, as the RRDB pruning cell does: the pruning
# helpers in this notebook build their channel masks with the global `device`,
# and indexing a CPU weight with a CUDA mask fails.
model = RDN(scale_factor=4, num_channels=3, num_features=64, growth_rate=64,
            num_blocks=16, num_layers=8).to(device)
print(f'Params number(Before Pruning): {count_num_of_parameters(model)}')
print(f'compression_rate: {compression_rate}')
print(f'Params number(Predict): {count_num_of_parameters(model) * compression_rate}')
pruned_model = prune_with_synflow(model, origin_channel, prune_channel_list, filter_unprune, filter_in, filter_out)
print(f'Params number(After Pruning): {count_num_of_parameters(pruned_model)}')

In [None]:
# RCAN naive l2 channel pruning
#
# Every Conv2d not listed in `filter_unprune` loses `amount` of its input
# channels (ranked by L2 norm over dim=1) and `amount` of its output channels
# (ranked by L2 norm over dim=0). Layers in `filter_in` keep all input
# channels; layers in `filter_out` keep all output channels. Channels are
# ranked on `model`'s weights, but the surviving values are copied from
# `rewind_model` (the checkpoint loaded from `rewind_path`).
# NOTE: each layer's masks are chosen independently of its neighbours
# (hence "naive").

rewind_model = load_model(rewind_path, device, model_name)
model_new = copy.deepcopy(model)
# Kind of filter
filter_unprune = ['sub_mean', 'add_mean']
filter_in = ['head.0']
filter_out = ['tail.1']
amount = 0.25
# Name -> module index of the rewind model, built once instead of rescanning
# rewind_model.named_modules() for every pruned layer.
rewind_modules = dict(rewind_model.named_modules())
# Pruning!
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
for name, module in model_new.named_modules():
    # Prune or not
    module_name = type(module).__name__
    unprune = 'Conv2d' not in module_name or name in filter_unprune
    logger.debug("=" * 20)
    logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
    if unprune: continue
    # Prepare pruning
    if name not in rewind_modules:
        raise KeyError(f"'{name}' not found in rewind model")
    rewind_module = rewind_modules[name]
    pre_module_weight_shape = module.weight.shape
    pre_module_bias_shape = None if module.bias is None else module.bias.shape
    # Kept channels are read from the 0/1 `weight_mask` that ln_structured
    # attaches to the throw-away copy, not from the signed sum of the pruned
    # weights (which would also drop a kept channel whose weights sum to 0).
    # Prune in_channel
    if name not in filter_in:
        temp_module = copy.deepcopy(module)
        prune.ln_structured(temp_module, 'weight', amount=amount, n=2, dim=1)
        in_channel_mask = temp_module.weight_mask.sum(dim=(0, 2, 3)) != 0
    else:
        in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool).to(device)
    # Prune out_channel
    if name not in filter_out:
        temp_module = copy.deepcopy(module)
        prune.ln_structured(temp_module, 'weight', amount=amount, n=2, dim=0)
        out_channel_mask = temp_module.weight_mask.sum(dim=(1, 2, 3)) != 0
    else:
        out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool).to(device)
    # Copy surviving values from the rewind checkpoint.
    module.weight = torch.nn.Parameter(rewind_module.weight[:, in_channel_mask][out_channel_mask])
    if module.bias is not None:
        module.bias = torch.nn.Parameter(rewind_module.bias[out_channel_mask])
    post_module_bias_shape = None if module.bias is None else module.bias.shape
    logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
    logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
    logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {post_module_bias_shape}")
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

In [None]:
# RRDB naive l2 channel pruning
#
# Every Conv2d not listed in `filter_unprune` loses `amount` of its input
# channels (ranked by L2 norm over dim=1) and `amount` of its output channels
# (ranked by L2 norm over dim=0). Layers in `filter_in` keep all input
# channels; layers in `filter_out` keep all output channels. Channels are
# ranked on `model`'s weights, but the surviving values are copied from
# `rewind_model` (the checkpoint loaded from `rewind_path`).
# NOTE: each layer's masks are chosen independently of its neighbours
# (hence "naive").
import copy
from collections import OrderedDict

rewind_model = load_model(rewind_path, device, model_name)
model_new = copy.deepcopy(model)
# Kind of filter
filter_unprune = []
filter_in = ['conv_first']
filter_out = ['conv_last']
amount = 0.25
# Name -> module index of the rewind model, built once instead of rescanning
# rewind_model.named_modules() for every pruned layer.
rewind_modules = dict(rewind_model.named_modules())
# Pruning!
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
for name, module in model_new.named_modules():
    # Prune or not
    module_name = type(module).__name__
    unprune = 'Conv2d' not in module_name or name in filter_unprune
    logger.debug("=" * 20)
    logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
    if unprune: continue
    # Prepare pruning
    if name not in rewind_modules:
        raise KeyError(f"'{name}' not found in rewind model")
    rewind_module = rewind_modules[name]
    pre_module_weight_shape = module.weight.shape
    pre_module_bias_shape = None if module.bias is None else module.bias.shape
    # Kept channels are read from the 0/1 `weight_mask` that ln_structured
    # attaches to the throw-away copy, not from the signed sum of the pruned
    # weights (which would also drop a kept channel whose weights sum to 0).
    # Prune in_channel
    if name not in filter_in:
        temp_module = copy.deepcopy(module)
        prune.ln_structured(temp_module, 'weight', amount=amount, n=2, dim=1)
        in_channel_mask = temp_module.weight_mask.sum(dim=(0, 2, 3)) != 0
    else:
        in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool).to(device)
    # Prune out_channel
    if name not in filter_out:
        temp_module = copy.deepcopy(module)
        prune.ln_structured(temp_module, 'weight', amount=amount, n=2, dim=0)
        out_channel_mask = temp_module.weight_mask.sum(dim=(1, 2, 3)) != 0
    else:
        out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool).to(device)
    # Copy surviving values from the rewind checkpoint.
    module.weight = torch.nn.Parameter(rewind_module.weight[:, in_channel_mask][out_channel_mask])
    if module.bias is not None:
        module.bias = torch.nn.Parameter(rewind_module.bias[out_channel_mask])
    post_module_bias_shape = None if module.bias is None else module.bias.shape
    logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
    logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
    logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {post_module_bias_shape}")
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

In [None]:
# RDN naive l2 channel pruning
#
# Every Conv2d not listed in `filter_unprune` loses `amount` of its input
# channels (ranked by L2 norm over dim=1) and `amount` of its output channels
# (ranked by L2 norm over dim=0). Layers in `filter_in` keep all input
# channels; layers in `filter_out` keep all output channels. Channels are
# ranked on `model`'s weights, but the surviving values are copied from
# `rewind_model` (the checkpoint loaded from `rewind_path`).
# NOTE: each layer's masks are chosen independently of its neighbours
# (hence "naive").
import copy
from collections import OrderedDict

rewind_model = load_model(rewind_path, device, model_name)
model_new = copy.deepcopy(model)
# Kind of filter
filter_unprune = []
filter_in = ['sfe1']
filter_out = ['output']
amount = 0.21428571428571428571428571428571  # = 3/14
# Name -> module index of the rewind model, built once instead of rescanning
# rewind_model.named_modules() for every pruned layer.
rewind_modules = dict(rewind_model.named_modules())
# Pruning!
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
for name, module in model_new.named_modules():
    # Prune or not
    module_name = type(module).__name__
    unprune = 'Conv2d' not in module_name or name in filter_unprune
    logger.debug("=" * 20)
    logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
    if unprune: continue
    # Prepare pruning
    if name not in rewind_modules:
        raise KeyError(f"'{name}' not found in rewind model")
    rewind_module = rewind_modules[name]
    pre_module_weight_shape = module.weight.shape
    pre_module_bias_shape = None if module.bias is None else module.bias.shape
    # Kept channels are read from the 0/1 `weight_mask` that ln_structured
    # attaches to the throw-away copy, not from the signed sum of the pruned
    # weights (which would also drop a kept channel whose weights sum to 0).
    # Prune in_channel
    if name not in filter_in:
        temp_module = copy.deepcopy(module)
        prune.ln_structured(temp_module, 'weight', amount=amount, n=2, dim=1)
        in_channel_mask = temp_module.weight_mask.sum(dim=(0, 2, 3)) != 0
    else:
        in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool).to(device)
    # Prune out_channel
    if name not in filter_out:
        temp_module = copy.deepcopy(module)
        prune.ln_structured(temp_module, 'weight', amount=amount, n=2, dim=0)
        out_channel_mask = temp_module.weight_mask.sum(dim=(1, 2, 3)) != 0
    else:
        out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool).to(device)
    # Copy surviving values from the rewind checkpoint.
    module.weight = torch.nn.Parameter(rewind_module.weight[:, in_channel_mask][out_channel_mask])
    if module.bias is not None:
        module.bias = torch.nn.Parameter(rewind_module.bias[out_channel_mask])
    post_module_bias_shape = None if module.bias is None else module.bias.shape
    logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
    logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
    logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {post_module_bias_shape}")
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

In [None]:
# Prune CARN naively
#
# Every Conv2d not listed in `filter_unprune` loses `amount` of its input
# channels (ranked by L2 norm over dim=1) and `amount` of its output channels
# (ranked by L2 norm over dim=0). Layers in `filter_in` keep all input
# channels; layers in `filter_out` keep all output channels. Channels are
# ranked on `model`'s weights, but the surviving values are copied from
# `rewind_model` (the checkpoint loaded from `rewind_path`).
# NOTE: each layer's masks are chosen independently of its neighbours
# (hence "naive").
import copy
from collections import OrderedDict

rewind_model = load_model(rewind_path, device, model_name)
# Both models are unwrapped via `.module` (presumably DataParallel wrappers —
# confirm against load_model), so the layer names below carry no prefix.
rewind_model = rewind_model.module
model_new = copy.deepcopy(model)
model_new = model_new.module
# Kind of filter
filter_unprune = ['sub_mean.shifter', 'add_mean.shifter']
filter_in = ['entry']
filter_out = ['exit']
amount = 0.90625  # = 29/32
# Name -> module index of the rewind model, built once instead of rescanning
# rewind_model.named_modules() for every pruned layer.
rewind_modules = dict(rewind_model.named_modules())
# Pruning!
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
for name, module in model_new.named_modules():
    # Prune or not
    module_name = type(module).__name__
    unprune = 'Conv2d' not in module_name or name in filter_unprune
    logger.debug("=" * 20)
    logger.debug(f"{name}, {module_name}: Prune-{not unprune}")
    if unprune: continue
    # Prepare pruning
    if name not in rewind_modules:
        raise KeyError(f"'{name}' not found in rewind model")
    rewind_module = rewind_modules[name]
    pre_module_weight_shape = module.weight.shape
    pre_module_bias_shape = None if module.bias is None else module.bias.shape
    # Kept channels are read from the 0/1 `weight_mask` that ln_structured
    # attaches to the throw-away copy, not from the signed sum of the pruned
    # weights (which would also drop a kept channel whose weights sum to 0).
    # Prune in_channel
    if name not in filter_in:
        temp_module = copy.deepcopy(module)
        prune.ln_structured(temp_module, 'weight', amount=amount, n=2, dim=1)
        in_channel_mask = temp_module.weight_mask.sum(dim=(0, 2, 3)) != 0
    else:
        in_channel_mask = torch.ones(module.weight.shape[1], dtype=torch.bool).to(device)
    # Prune out_channel
    if name not in filter_out:
        temp_module = copy.deepcopy(module)
        prune.ln_structured(temp_module, 'weight', amount=amount, n=2, dim=0)
        out_channel_mask = temp_module.weight_mask.sum(dim=(1, 2, 3)) != 0
    else:
        out_channel_mask = torch.ones(module.weight.shape[0], dtype=torch.bool).to(device)
    # Copy surviving values from the rewind checkpoint.
    module.weight = torch.nn.Parameter(rewind_module.weight[:, in_channel_mask][out_channel_mask])
    if module.bias is not None:
        module.bias = torch.nn.Parameter(rewind_module.bias[out_channel_mask])
    post_module_bias_shape = None if module.bias is None else module.bias.shape
    logger.debug(f"mask_index.shape: {out_channel_mask.shape}")
    logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
    logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {post_module_bias_shape}")
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

In [None]:
# Prune MSRResNet
#
# Naive L2 structured pruning, applied in named_modules() order to every
# module whose name contains 'conv':
#   - a pruned conv keeps the output channels whose filters have the largest
#     L2 norm (it drops `amount` of them);
#   - the next pruned conv has its input channels sliced with that same mask;
#   - upconv2 is not touched;
#   - upconv1 only has its input channels sliced, using conv_first's mask;
#   - conv_first and HRconv keep all their input channels;
#   - conv_last keeps all of its output channels.

import copy

amount = 0.9  # fraction of output channels removed from each pruned conv
model_new = copy.deepcopy(model)
logger.info('Params number(Before prune): {}'.format(count_num_of_parameters(model_new)))
pre_mask_index = torch.ones(3, dtype=torch.bool).to(device)
conv_first_mask_index = None
for name, module in model_new.named_modules():
    if 'conv' not in name or name == 'upconv2':
        logger.debug("=" * 20)
        logger.debug(f"{name}: Unpruned")
        continue
    # DEBUG ----------------------------------------------------
    logger.debug("=" * 20)
    logger.debug(f"{name}: Pruned")
    pre_module_weight_shape = module.weight.shape
    pre_module_bias_shape = module.bias.shape
    # DEBUG ----------------------------------------------------
    if name == 'upconv1':
        # Input channels follow conv_first's kept channels; outputs untouched.
        logger.debug(f"conv_first_mask_index.shape: {conv_first_mask_index.shape}")
        module.weight = torch.nn.Parameter(module.weight[:, conv_first_mask_index])
        # DEBUG ----------------------------------------------------
        logger.debug(f"mask_index.shape: {conv_first_mask_index.shape}")
        logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
        logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {module.bias.shape}")
        # DEBUG ----------------------------------------------------
        continue
    # Prune in_channel
    if name not in ['conv_first', 'HRconv']:  # Without first
        logger.debug(f"pre_mask_index.shape: {pre_mask_index.shape}")
        module.weight = torch.nn.Parameter(module.weight[:, pre_mask_index])
    if name == 'conv_last':
        # All output channels of the last layer must be kept, so it must not
        # go through ln_structured at all: that call prunes in place, and with
        # amount=0.9 on e.g. 3 output filters it removes round(0.9 * 3) = 3 of
        # them, zeroing the whole layer.
        logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
        continue
    # Get pruning mask. Read the kept filters from the 0/1 `weight_mask`
    # rather than from the signed sum of the pruned weights, which would also
    # drop a kept filter whose weights happen to sum to 0.
    prune.ln_structured(module, 'weight', amount=amount, n=2, dim=0)
    mask_index = module.weight_mask.sum(dim=(1, 2, 3)) != 0
    prune.remove(module, 'weight')
    pre_mask_index = mask_index
    if name == 'conv_first':
        conv_first_mask_index = mask_index
    # Prune out_channel
    module.weight = torch.nn.Parameter(module.weight[mask_index, :])
    module.bias = torch.nn.Parameter(module.bias[mask_index])
    # DEBUG ----------------------------------------------------
    logger.debug(f"mask_index.shape: {mask_index.shape}")
    logger.debug(f"module.weight.shape: {pre_module_weight_shape} --> {module.weight.shape}")
    logger.debug(f"module.bias.shape: {pre_module_bias_shape} --> {module.bias.shape}")
    # DEBUG ----------------------------------------------------
logger.info('Params number(After prune): {}'.format(count_num_of_parameters(model_new)))

In [None]:
# Evaluate the pruned model: run it over the LR inputs in L_folder (outputs
# presumably written to P_folder — confirm against `test`), then score P_folder
# against the HR references in H_folder.
# For the unpruned baseline, pass `model` instead of `model_new`.
test(model_new, L_folder, P_folder, logger, True)
psnr_after_pruning = calculate_psnr(H_folder, P_folder, logger)
logger.info(f'PSNR after pruning {psnr_after_pruning}')

In [None]:
# Briefly fine-tune the pruned model on a constant dummy pair: an all-ones
# LR input and an all-ones HR target `scale` times larger.
scale = 4
num_finetune_steps = 50
img_lr = torch.ones((1, 3, 32, 32)).to(device)
img_hr = torch.ones((1, 3, 32 * scale, 32 * scale)).to(device)
criterion = torch.nn.MSELoss()
# Train model a little bit
optimizer = torch.optim.Adam(pruned_model.parameters())
for step in range(num_finetune_steps):
    optimizer.zero_grad()
    img_sr = pruned_model(img_lr)
    # MSELoss takes (input=prediction, target=ground truth).
    loss = criterion(img_sr, img_hr)
    loss.backward()
    optimizer.step()
    # .item() prints the scalar, not the full tensor repr with device/grad_fn.
    print(f"[{step + 1}/{num_finetune_steps}] loss: {loss.item()}")

In [None]:
# Save the fine-tuned pruned weights under the 'net' key.
pruned_model_path = "model/RRDB_32_srflow2_finetune.pth"
# torch.save does not create missing directories.
Path(pruned_model_path).parent.mkdir(parents=True, exist_ok=True)
torch.save({
    'net': pruned_model.state_dict()
}, pruned_model_path)