## Setup

In [29]:
#@title Global Flags
#@markdown ---
# Enables the step-by-step single-case pruning cells guarded by `if EVAL_MODEL_SINGLE_CASE:`.
EVAL_MODEL_SINGLE_CASE = False                   #@param {type:"boolean"}
# NOTE(review): not referenced in the visible part of the notebook; presumably meant to gate the pruning sweeps - confirm.
EVAL_MODEL_EXTENDED_CASE = True                  #@param {type:"boolean"}

In [30]:
#@title Model's Hyper-Params
#@markdown ---
# Number of hidden features of the Siren model built below.
N_HF=35 #@param {type:"integer"}
# Number of hidden layers of the Siren model built below.
N_HL=9  #@param {type:"integer"}
# Side length forwarded as `sidelength` to the dataset wrapper (center crop and coordinate grid).
SIDELENGTH=256 #@param {type:"integer"}
# Device label stored in the options tuple; any value other than "cpu" is treated as a CUDA request by the helpers.
DEVICE = "cpu" #@param ["cpu", "cuda", "gpu"]
# Batch size used when building the DataLoaders.
BATCH_SIZE=1 #@param {type:"integer"}
# File path of the saved state dict loaded into the model.
MODEL_PATH='/content/model_final.pth' #@param {type:"string"}

In [31]:
# Number of repetitions per pruning rate, passed as `number_trials` to compute_pruning_evaluation.
NUMBER_TRIALS_RATE_PRUNING = 5 #@param {type:"integer"}
# NOTE(review): presumably the repetitions per absolute pruning amount; not used in the visible cells - confirm.
NUMBER_TRIALS_ABS_PRUNING = 5 #@param {type:"integer"}

### Libs

In [32]:
# Installing third party dependencies
print("Installing required libraries...")

old_requirements = '/content/tmp_requirements.txt'
!pip freeze > {old_requirements}
dependencies_list = "cmapy,sk-video,pytorch-model-summary,ConfigArgParse,tabulate,chart_studio".split(",")

with open(old_requirements) as f:
    old_requirements_list = f.read().split("\n")
    for a_req in dependencies_list:
        found_req = False
        for old_req in old_requirements_list:
            if old_req.startswith(a_req):
                print(f"{a_req} already installed!")
                found_req = True
                break
        if found_req is False:
            !pip install {a_req} -q
    pass
!rm -f {old_requirements}

Installing required libraries...
cmapy already installed!
sk-video already installed!
pytorch-model-summary already installed!
ConfigArgParse already installed!
tabulate already installed!


In [33]:
from __future__ import print_function
from __future__ import division

# --------------------------------------------- #
# Standard Library, plus some Third Party Libraries
# --------------------------------------------- #

DASH_TEMPLATES_LIST = ["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none"]

from PIL import Image
from functools import partial
from pprint import pprint
from tqdm import tqdm
from typing import Tuple, Union


import configargparse
import copy
import collections
import datetime
import functools
import h5py
import logging
import math
import os
import operator
import pickle
import random
import shutil
import sys
import re
import tabulate 
import time
# import visdom


from collections import OrderedDict
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# --------------------------------------------- #
# Data Science and Machine Learning Libraries
# --------------------------------------------- #
import matplotlib
import matplotlib.pyplot as plt
matplotlib.style.use('ggplot')

import numpy as np
import pandas as pd
import sklearn

from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import train_test_split

# --------------------------------------------- #
# Torch
# --------------------------------------------- #
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import DataLoader, Dataset
    import torch.quantization
    import torch.nn.utils.prune as prune
except:
    print("torch not available!")
    pass


# --------------------------------------------- #
# Import: TorchVision
# --------------------------------------------- #
try:
    import torchvision
    from torchvision import datasets
    from torchvision import transforms
    from torchvision.transforms import Resize, Compose, ToTensor, CenterCrop, Normalize
    from torchvision.utils import save_image
except:
    print("torchvision library not available!")
    pass

# Plotly imports.
# ----------------------------------------------- #
import chart_studio.plotly as py
import plotly.figure_factory as ff
import plotly.express as px

# --------------------------------------------- #
# Import: skimage
# --------------------------------------------- #
try:
    import skimage
    import skimage.metrics as skmetrics
    from skimage.metrics import peak_signal_noise_ratio as psnr
    from skimage.metrics import structural_similarity as ssim
    from skimage.metrics import mean_squared_error
except:
    print("skimage library not available!")
    pass

### Functions

#### Utils

In [34]:
def set_seeds(seed):
    """Seed every pseudo-random generator used by the notebook.
    Params
    ------
    `seed` - int object, value handed to torch, numpy's global generator and
    Python's `random` module.\n
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

In [35]:
def check_device_and_weigths_to_laod(model_fp32, device = 'cpu', model_path = None):
    """Move a PyTorch model to the requested device and optionally load its weights.

    Any `device` other than 'cpu' is treated as a CUDA request. When CUDA is
    not available, or moving/loading on CUDA fails, the model is placed on the
    CPU instead (best-effort fallback).

    Params
    ------
    `model_fp32` - PyTorch model to be moved to the target device.\n
    `device` - str object, 'cpu' for CPU; any other value (e.g. 'cuda', 'gpu') requests CUDA.\n
    `model_path` - either None (no weights loaded) or a file path to a saved state dict.\n
    Return
    ------
    `model_fp32` - the same model, placed on the selected device with weights loaded if requested.\n
    """
    def _place_and_load(target):
        # Single code path shared by the CPU, CUDA and fallback cases.
        placed = model_fp32.to(target)
        if model_path is not None:
            state_dict = torch.load(model_path, map_location=torch.device(target))
            placed.load_state_dict(state_dict)
        return placed

    if device != 'cpu' and torch.cuda.is_available():
        try:
            return _place_and_load('cuda')
        except (RuntimeError, AssertionError):
            # CUDA reported as available but unusable (or the load failed):
            # retry on CPU; a genuine weight mismatch is raised again there.
            pass
    return _place_and_load('cpu')

def get_size_of_model(model):
    """Return the on-disk size of the model's state dictionary.

    The state dict is saved as "temp.p" inside a private temporary directory,
    so an existing "temp.p" in the working directory is never overwritten or
    deleted, and the scratch file is removed even if saving fails.

    Params
    ------
    `model` - PyTorch like model.\n
    Return
    ------
    `model_size` - int python object, size of the saved state dictionary expressed in bytes.\n
    """
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Same file name as before; only its directory changes.
        scratch_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), scratch_path)
        model_size = os.path.getsize(scratch_path)
    return model_size

def compute_desired_metrices(model_output, gt, data_range=1.):
    """Compute PSNR and SSIM between a model output and its reference.

    Both tensors are flattened to 1-D arrays whose length is the size of
    dimension 1 of `model_output`, then mapped with x / 2 + 0.5; the model
    output is additionally clipped to [0, 1].
    NOTE(review): SSIM is therefore evaluated on a flattened 1-D signal,
    not on a 2-D image - confirm this is intended.

    Params:
    -------
    `model_output` - output produced by a Pytorch model\n
    `gt` - reference data\n
    `data_range` - value range passed to both metrics\n

    Return:
    -------
    `val_psnr, val_mssim` - scores from metrices PSNR, and SSIM
    """
    n_samples = model_output.size()[1]

    def _as_unit_range(tensor):
        flat = tensor.cpu().view(n_samples).detach().numpy()
        return (flat / 2.) + 0.5

    reference = _as_unit_range(gt)
    estimate = np.clip(_as_unit_range(model_output), a_min=0., a_max=1.)

    scores = (
        psnr(reference, estimate, data_range=data_range),
        ssim(reference, estimate, data_range=data_range),
    )
    return scores

def get_data_ready_for_model(model_input, gt, quantization_enabled = None, device = 'cpu'):
    """Extract the tensors the model expects and move them to the target device.

    Any `device` other than 'cpu' requests CUDA. If CUDA is not available the
    tensors stay on the CPU, matching the fallback applied by
    `check_device_and_weigths_to_laod`, so data and model end up on the same device.

    Params:
    -------
    `model_input` - dict holding the input tensor under key 'coords'\n
    `gt` - dict holding the reference tensor under key 'img'\n
    `quantization_enabled` - str object, quantization technique name; accepted for
    interface compatibility, currently has no effect\n
    `device` - str object, allowed values: 'cpu', 'gpu', 'cuda'\n

    Return:
    -------
    `model_input, gt` - tensors ready to be fed into the PyTorch model
    """
    use_cuda = device != 'cpu' and torch.cuda.is_available()
    target = 'cuda' if use_cuda else 'cpu'
    return model_input['coords'].to(target), gt['img'].to(target)

#### Data Loader

In [36]:
def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.

    Params
    ------
    `sidelen` - int (same length on every axis) or sequence with one length per axis.\n
    `dim` - number of grid dimensions; only 2 and 3 are supported.\n
    Return
    ------
    `pixel_coords` - float tensor of shape (number of grid points, dim).\n
    Raises
    ------
    `NotImplementedError` - when `dim` is neither 2 nor 3.\n
    '''
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)

    if dim not in (2, 3):
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    axes = np.mgrid[tuple(slice(0, sidelen[axis]) for axis in range(dim))]
    pixel_coords = np.stack(axes, axis=-1)[None, ...].astype(np.float32)
    for axis in range(dim):
        # A length-1 axis would divide by zero and produce NaN; clamp the
        # divisor on every axis so such an axis maps to -1.
        pixel_coords[..., axis] = pixel_coords[..., axis] / max(sidelen[axis] - 1, 1)

    pixel_coords -= 0.5
    pixel_coords *= 2.
    pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)
    return pixel_coords


def get_input_image(opt):
    """Return the dataset, PIL image and resolution of the input image.

    When `opt.image_filepath` is None the default camera image from
    `skimage.data.camera()` is used, otherwise the file at that path is opened.
    If `opt.sidelength` is None it is set to the image resolution.
    NOTE(review): that assignment needs a mutable `opt`; with a namedtuple it
    would raise AttributeError - confirm callers always provide `sidelength`.
    """
    use_default_image = opt.image_filepath is None
    if use_default_image:
        img_dataset = Camera()
        img = Image.fromarray(skimage.data.camera())
    else:
        img_dataset = ImageFile(opt.image_filepath)
        img = Image.open(opt.image_filepath)

    image_resolution = img.size
    if opt.sidelength is None:
        opt.sidelength = image_resolution

    return img_dataset, img, image_resolution

class Camera(Dataset):
    """Single-item dataset wrapping the `skimage.data.camera()` test image.

    Params
    ------
    `downsample_factor` - when greater than 1, items are a copy of the image
    resized to (512 / downsample_factor) pixels per side.\n
    """
    def __init__(self, downsample_factor=1):
        super().__init__()
        self.downsample_factor = downsample_factor
        self.img = Image.fromarray(skimage.data.camera())
        self.img_channels = 1

        if downsample_factor > 1:
            size = (int(512 / downsample_factor),) * 2
            # Image.ANTIALIAS was an alias of LANCZOS and was removed in
            # Pillow 10; LANCZOS selects the same filter on old and new versions.
            self.img_downsampled = self.img.resize(size, Image.LANCZOS)

    def __len__(self):
        # The dataset always holds exactly one image.
        return 1

    def __getitem__(self, idx):
        # `idx` is ignored: every index yields the same image.
        if self.downsample_factor > 1:
            return self.img_downsampled
        else:
            return self.img

class ImageFile(Dataset):
    """Single-item dataset wrapping an image opened from `filename`."""
    def __init__(self, filename):
        super().__init__()
        self.img = Image.open(filename)
        # Count the actual bands: the length of the mode string is wrong for
        # modes such as "YCbCr" (3 bands) or "I;16" (1 band).
        self.img_channels = len(self.img.getbands())

    def __len__(self):
        # The dataset always holds exactly one image.
        return 1

    def __getitem__(self, idx):
        # `idx` is ignored: every index yields the same image.
        return self.img

class Implicit2DWrapper(torch.utils.data.Dataset):
    """Pair each image of `dataset` with a coordinate grid.

    Params
    ------
    `dataset` - dataset yielding images accepted by the torchvision transforms;
    must expose `img_channels`.\n
    `sidelength` - int or (height, width) used for the center crop and the grid.\n
    `compute_diff` - None, 'gradients', 'laplacian' or 'all'; selects which
    derivative targets are added to the ground-truth dict.\n
    """
    def __init__(self, dataset, sidelength=None, compute_diff=None):

        if isinstance(sidelength, int):
            sidelength = (sidelength, sidelength)
        self.sidelength = sidelength

        self.transform = Compose([
            # Resize(sidelength),
            CenterCrop(sidelength),
            ToTensor(),
            Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
        ])

        self.compute_diff = compute_diff
        self.dataset = dataset
        self.mgrid = get_mgrid(sidelength)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return `(in_dict, gt_dict)` for item `idx`.

        `in_dict` holds 'idx' and 'coords'; `gt_dict` holds 'img' plus
        'gradients' and/or 'laplace' depending on `compute_diff`.
        """
        wants_gradients = self.compute_diff in ('gradients', 'all')
        wants_laplace = self.compute_diff in ('laplacian', 'all')
        if wants_gradients or wants_laplace:
            # scipy.ndimage is not imported at notebook level; import it here
            # so the derivative targets do not fail with a NameError.
            import scipy.ndimage

        img = self.transform(self.dataset[idx])

        # Scaling applied in place before the filters, so the returned 'img'
        # is scaled as well.
        if self.compute_diff == 'gradients':
            img *= 1e1
        elif self.compute_diff == 'laplacian':
            img *= 1e4

        if wants_gradients:
            gradx = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
            grady = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
        if wants_laplace:
            laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]

        img = img.permute(1, 2, 0).view(-1, self.dataset.img_channels)

        in_dict = {'idx': idx, 'coords': self.mgrid}
        gt_dict = {'img': img}

        if wants_gradients:
            gradients = torch.cat((torch.from_numpy(gradx).reshape(-1, 1),
                                   torch.from_numpy(grady).reshape(-1, 1)),
                                  dim=-1)
            gt_dict.update({'gradients': gradients})
        if wants_laplace:
            gt_dict.update({'laplace': torch.from_numpy(laplace).view(-1, 1)})

        return in_dict, gt_dict

def get_data_for_train(img_dataset, sidelength, batch_size):
    """Build the train and validation DataLoaders over the coordinate dataset.

    Both loaders wrap the same `Implicit2DWrapper` instance; only the train
    loader shuffles.
    Params
    ------
    `img_dataset` - PyTorch's DataSet like object providing the image(s).\n
    `sidelength` - either int object or list/tuple with width and height used for center cropping the input image.\n
    `batch_size` - int object, batch size of both loaders.\n
    Return
    ------
    `train_dataloader` - PyTorch DataLoader instance (shuffle=True).\n
    `val_dataloader` - PyTorch DataLoader instance (shuffle=False).\n
    """
    coord_dataset = Implicit2DWrapper(img_dataset, sidelength=sidelength, compute_diff=None)

    # Settings shared by the two loaders.
    shared_kwargs = dict(batch_size=batch_size, pin_memory=True, num_workers=0)
    train_dataloader = DataLoader(coord_dataset, shuffle=True, **shared_kwargs)
    val_dataloader = DataLoader(coord_dataset, shuffle=False, **shared_kwargs)

    return train_dataloader, val_dataloader

#### Evaluate Functions

In [37]:
def evaluate_model(model, eval_dataloader, device = 'cpu', loss_fn = nn.MSELoss(), quantization_enabled = None, verbose = 0, logging_flag = False, tqdm = None):
    """Evaluate a model in eval mode on the first batch of a DataLoader.

    Computes the loss, PSNR and SSIM of the model output against the
    reference, and times the forward pass only.
    Params
    ------
    `model` - PyTorch based model whose call returns a pair whose first item is the output\n
    `eval_dataloader` - PyTorch DataLoader like object; only its first batch is used\n
    `device` - str object, allowed values: 'cpu', 'gpu', 'cuda'; used to place the data\n
    `loss_fn` - Pytorch like Loss Function object\n
    `quantization_enabled` - str object, forwarded to `get_data_ready_for_model`\n
    `verbose` - int, not used by this function.\n
    `logging_flag` - bool, not used by this function.\n
    `tqdm` - tqdm object, not used by this function.\n

    Return
    ------
    `eval_scores` - np.array object, containing loss, psnr, ssim scores.\n
    `eta_eval` - float, seconds spent in the model's forward pass.\n
    """
    model.eval()
    with torch.no_grad():
        # -- Fetch and prepare the first batch.
        raw_input, raw_gt = next(iter(eval_dataloader))
        eval_input, eval_gt = get_data_ready_for_model(
            model_input = raw_input,
            gt = raw_gt,
            quantization_enabled = quantization_enabled,
            device = device)

        # --- Timed forward pass.
        tic = time.time()
        eval_output, _ = model(eval_input)
        eta_eval = time.time() - tic

        # --- Scores.
        eval_loss = loss_fn(eval_output, eval_gt)
        eval_psnr, eval_mssim = compute_desired_metrices(model_output = eval_output, gt = eval_gt)
        eval_scores = np.array([eval_loss.item(), eval_psnr, eval_mssim])
    return eval_scores, eta_eval


def evaluate_model_wrapper(model, opt, img_dataset, model_name, model_weight_path = None, logging=None, tqdm=None, verbose=0):
    """Evaluate a model and package the scores as a one-element list.

    Params
    ------
    `model` - PyTorch like object representing a Neural Network model.\n
    `opt` - object exposing `sidelength`, `batch_size` and `model_type` attributes.\n
    `img_dataset` - PyTorch's DataSet like object representing the data against which the model is evaluated.\n
    `model_name` - str like object, identifier of the current trial.\n
    `model_weight_path` - str like object; only referenced by the disabled quantization code.\n
    `logging` - only referenced by the disabled quantization code.\n
    `tqdm` - only referenced by the disabled quantization code.\n
    `verbose` - only referenced by the disabled quantization code.\n
    Return
    ------
    `eval_info_list` - python list object with one `EvalInfos` namedtuple (scores, forward time, footprint in bytes and 100.0 percent).\n
    """
    loaders = get_data_for_train(img_dataset, sidelength=opt.sidelength, batch_size=opt.batch_size)
    eval_dataloader = loaders[0]

    eval_field_names = "model_name,model_type,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent".split(",")
    EvalInfos = collections.namedtuple("EvalInfos", eval_field_names)

    # Evaluation always runs with CPU-placed data.
    eval_scores, eta_eval = evaluate_model(
        model=model,
        eval_dataloader=eval_dataloader,
        device='cpu')
    basic_model_size = get_size_of_model(model)

    row = [model_name, opt.model_type]
    row.extend(eval_scores)
    row.extend([eta_eval, basic_model_size, 100.0])
    eval_info_list = [EvalInfos._make(row)]

    # Disabled quantization evaluation, kept as an inert string literal.
    """
    if opt.dynamic_quant != []:
        for a_dynamic_type in opt.dynamic_quant:
            eval_scores, eta_eval, model_size = \
                _evaluate_dynamic_quant(
                    opt,
                    dtype=a_dynamic_type,
                    img_dataset=img_dataset,
                    model = copy.deepcopy(model),
                    model_weight_path = model_weight_path,
                    device = 'cpu',
                    qconfig = 'fbgemm')
            eval_info = EvalInfos._make([model_name, f'Quant-{str(a_dynamic_type)}'] + list(eval_scores) + [eta_eval, model_size, model_size / basic_model_size * 100])
            eval_info_list.append(eval_info)
            pass
        pass
    
    table_vals = list(map(operator.methodcaller("values"), map(operator.methodcaller("_asdict"), eval_info_list)))
    table = tabulate.tabulate(table_vals, headers=eval_field_names)
    _log_infos(info_msg = f"{table}", header_msg = None, logging=logging, tqdm=tqdm, verbose=verbose)
    """
    return eval_info_list

### Models

In [38]:


class SineLayer(nn.Module):
    """Linear layer followed by sin(omega_0 * x).

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.

    If is_first=True, omega_0 is a frequency factor which simply multiplies the activations
    before the nonlinearity. Different signals may require different omega_0 in the first
    layer - this is a hyperparameter.

    If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude
    of activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights uniformly in [-bound, bound]; the bias is left untouched."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation)

    def forward_with_intermediate(self, input):
        """Return `(sin(z), z)` with z = omega_0 * linear(input); used to inspect activation distributions."""
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation
    
    
class Siren(nn.Module):
    """Stack of SineLayers with an optional plain linear output layer.

    The network is one first SineLayer, `hidden_layers` hidden SineLayers and
    either a final `nn.Linear` (when `outermost_linear` is True) or a final
    SineLayer, all held in `self.net` as an `nn.Sequential`.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, 
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]
        for _ in range(hidden_layers):
            layers.append(SineLayer(hidden_features, hidden_features,
                                    is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                final_linear.weight.uniform_(-bound, bound)
            layers.append(final_linear)
        else:
            layers.append(SineLayer(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Return `(output, coords)`, where `coords` is a detached clone that requires grad."""
        # The clone allows taking derivatives w.r.t. the input.
        coords = coords.clone().detach().requires_grad_(True)
        return self.net(coords), coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns the input and every intermediate activation in an OrderedDict.
        Only used for visualizing activations later!'''
        activations = OrderedDict()
        activation_count = 0

        def _key(layer):
            # Key format: "<layer class repr>_<running counter>".
            return '_'.join((str(layer.__class__), "%d" % activation_count))

        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for layer in self.net:
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)
                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()
                # SineLayers record the pre-activation first ...
                activations[_key(layer)] = intermed
                activation_count += 1
            else:
                x = layer(x)
                if retain_grad:
                    x.retain_grad()

            # ... then every layer records its output.
            activations[_key(layer)] = x
            activation_count += 1

        return activations


### Prune Routine

In [39]:
def get_params_to_prune(model, module_set = {torch.nn.Linear}) -> tuple:
    """Collect the (module, 'weight') pairs to be pruned.

    Each matching module is listed exactly once, even when it is an instance
    of several classes contained in `module_set` (e.g. a class and its subclass).
    Params
    ------
    `model` - torch.nn.Module, DNN architecture.\n
    `module_set` - set of torch.nn.* classes whose instances are pruned; it is only read.\n
    Return
    ------
    `parameters_to_prune` - tuple of pairs (torch.nn.Module, 'weight').\n
    """
    kinds = tuple(module_set)
    return tuple(
        (module, 'weight')
        for _, module in model.named_modules()
        if isinstance(module, kinds)
    )


def remove_to_prune(model, module_set = {torch.nn.Linear}) -> torch.nn.Module:
    """Make pruning permanent by applying prune.remove(<module>, 'weight') to matching modules.

    Matching modules that carry no 'weight' pruning re-parametrization
    (no `weight_orig` attribute) are skipped, since `prune.remove` raises for them.
    Params
    ------
    `model` - torch.nn.Module, DNN architecture.\n
    `module_set` - set of torch.nn.* classes whose instances are processed; it is only read.\n
    Return
    ------
    `model` - torch.nn.Module, the same model object, updated in place.\n
    """
    kinds = tuple(module_set)
    for _, module in model.named_modules():
        if isinstance(module, kinds) and hasattr(module, 'weight_orig'):
            # Parameter name fixed from the misspelt 'weigth', which made
            # prune.remove raise for every module.
            prune.remove(module, 'weight')
    return model


def show_model_sparsity(model: torch.nn.Module) -> None:
    """Print the share of zero weights per torch.nn.Linear layer and overall.

    Layers without weight elements, and models without any Linear layer, are
    reported without raising a ZeroDivisionError.
    Params
    ------
    `model` - torch.nn.Module, DNN architecture.
    """
    zero_elements, n_elements = 0, 0
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        layer_zeros = int(torch.sum(module.weight == 0))
        layer_total = module.weight.nelement()
        layer_sparsity = 100. * float(layer_zeros) / float(layer_total) if layer_total else 0.
        print("Local sparsity({}): {:.2f}%".format(name, layer_sparsity))
        zero_elements += layer_zeros
        n_elements += layer_total

    if n_elements:
        print("Global sparsity: {:.2f}%".format(100. * float(zero_elements) / float(n_elements)))
    else:
        print("Global sparsity: n/a (no Linear weights found)")


def compute_pruning_evaluation(
    model: torch.nn.Module,
    image_dataset,
    amount: float = 0.2,
    pruning_method = prune.L1Unstructured,
    number_trials: int = 10, device: str = 'cpu') -> list:
    """Prune copies of a model globally and evaluate each pruned copy.

    For every trial the seeds are set to the trial number, the model is
    deep-copied, the 'weight' of all torch.nn.Linear modules is pruned with
    `prune.global_unstructured`, and the copy is evaluated. The input model is
    never modified. Architecture options are read from the notebook-level
    constants N_HF, N_HL, SIDELENGTH, BATCH_SIZE and MODEL_PATH.
    Params
    ------
    `model` - torch.nn.Module.\n
    `image_dataset` - PyTorch Dataset.\n
    `amount` - forwarded to `prune.global_unstructured`; an int is labelled without decimals, anything else with two.\n
    `pruning_method` - pruning technique to be adopted.\n
    `number_trials` - number of times repeating the calculation.\n
    `device` - str = 'cpu', only stored in the options tuple.\n
    Return
    ------
    `eval_info_list` - list object containing the results of all trials.\n
    """
    # Last dotted component of the method's repr, stripped of quotes and '>'.
    method_repr = str(pruning_method).split(" ")[1]
    name_pruning_method = method_repr.split(".")[-1].replace('>', '').replace("'","")

    eval_info_list = []
    for trial_no in range(number_trials):
        set_seeds(seed = trial_no)
        model_copied = copy.deepcopy(model)
        prune.global_unstructured(
            get_params_to_prune(model_copied, module_set = {torch.nn.Linear}),
            pruning_method=pruning_method,
            amount=amount,
        )

        amount_label = f'{amount:.0f}' if isinstance(amount, int) else f'{amount:.2f}'
        model_type = f'{name_pruning_method}_{amount_label}'

        # Field order: n_hf, n_hl, sidelength, device, batch_size,
        # model_type, model_path, image_filepath.
        opt = OptionModel._make((
            N_HF,
            N_HL,
            SIDELENGTH,
            f'{device}',
            BATCH_SIZE,
            f'{model_type}',
            f'{MODEL_PATH}',
            None,
        ))
        eval_info_list.extend(
            evaluate_model_wrapper(
                model = model_copied,
                opt = opt,
                img_dataset = image_dataset,
                model_name = f'model.{opt.n_hf}.{opt.n_hl}.{trial_no}',
                model_weight_path = None,
                logging=None,
                tqdm=None,
                verbose=0))
    return eval_info_list

## Tests

In [40]:
# Options of the baseline ("Basic") model, in the field order of OptionModel.
arch_hyperparams = collections.OrderedDict([
    ("n_hf", N_HF),
    ("n_hl", N_HL),
    ("sidelength", SIDELENGTH),
    ("device", DEVICE),
    ("batch_size", BATCH_SIZE),
    ("model_type", 'Basic'),
    ("model_path", f'{MODEL_PATH}'),
    ("image_filepath", None),
])

In [41]:
# Immutable options record whose fields mirror the keys of `arch_hyperparams`.
OptionModel = collections.namedtuple('OptionModel', list(arch_hyperparams))
opt = OptionModel(*arch_hyperparams.values())
opt

OptionModel(n_hf=35, n_hl=9, sidelength=256, device='cpu', batch_size=1, model_type='Basic', model_path='/content/model_final.pth', image_filepath=None)

In [42]:
# Siren mapping 2-D coordinates to one value, with a plain linear output layer.
model = Siren(
    in_features=2,
    hidden_features=int(arch_hyperparams['n_hf']),
    hidden_layers=int(arch_hyperparams['n_hl']),
    out_features=1,
    outermost_linear=True,
)

In [43]:
# Place the model on the CPU and load the saved state dict from `opt.model_path`.
model = check_device_and_weigths_to_laod(model_fp32 = model, device = 'cpu', model_path = opt.model_path)

In [44]:
# Hidden layers are all entries of `model.net` except the first and the last.
n_overall_layers = len(model.net)
n_hidden_layers = len(model.net[1:-1])
print('Number of Hidden layers:', n_hidden_layers)
print('Number of Overall layers:', n_overall_layers)

Number of Hidden layers: 9
Number of Overall layers: 11


In [45]:
# print(list(module.named_buffers()))

In [46]:
# Only the dataset is needed here; the PIL image and its resolution are discarded.
image_dataset, _, _ = get_input_image(opt)

In [47]:
# Baseline evaluation of the unpruned model; later cells extend this list.
eval_info_list = evaluate_model_wrapper(
    model=model,
    opt=opt,
    img_dataset=image_dataset,
    model_name=f'model-{opt.n_hf}-{opt.n_hl}',
    model_weight_path=None,
    logging=None, tqdm=None, verbose=0,
)

In [48]:
# Render the collected evaluation records as a plain-text table.
eval_field_names = "model_name,model_type,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent".split(",")
table_vals = [info._asdict().values() for info in eval_info_list]
table = tabulate.tabulate(table_vals, headers=eval_field_names)
print(table)

model_name    model_type            mse    psnr_db      ssim    eta_seconds    footprint_byte    footprint_percent
------------  ------------  -----------  ---------  --------  -------------  ----------------  -------------------
model-35-9    Basic         0.000225895    42.4817  0.980631       0.277036             52657                  100


In [49]:
# Single-case walkthrough: list the names of all modules of the original model.
parameters_to_prune = ()
if EVAL_MODEL_SINGLE_CASE:
    model_2 = copy.deepcopy(model)
    for name, module in model.named_modules():
        print(name)

In [50]:
# Single-case walkthrough: collect (module, 'weight') pairs of the Linear layers of a fresh copy.
parameters_to_prune = []
if EVAL_MODEL_SINGLE_CASE:
    model_2 = copy.deepcopy(model)
    for name, module in model_2.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        print(name)
        parameters_to_prune.append((module, 'weight'))
    parameters_to_prune = tuple(parameters_to_prune)

In [51]:
# Single-case walkthrough: globally prune 20% of the collected weights by L1 magnitude.
if EVAL_MODEL_SINGLE_CASE:
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

In [52]:
# Single-case walkthrough: report per-layer and global sparsity of the pruned copy.
zero_elemenets, n_elements = 0, 0
if EVAL_MODEL_SINGLE_CASE:
    for name, module in model_2.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        print("Local sparsity({}): {:.2f}%".format(name, 100. * float(torch.sum(module.weight == 0)) / float(module.weight.nelement())))
        zero_elemenets += torch.sum(module.weight == 0)
        n_elements += module.weight.nelement()
    print("Global sparsity: {:.2f}%".format(100. * float(zero_elemenets)/ float(n_elements)))

In [53]:
# Single-case experiment: evaluate the pruned copy `model_2` and append the
# resulting records to `eval_info_list`.
if EVAL_MODEL_SINGLE_CASE:
    # Field order matters: the values are fed positionally to OptionModel.
    arch_hyperparams = collections.OrderedDict([
        ("n_hf", N_HF),
        ("n_hl", N_HL),
        ("sidelength", SIDELENGTH),
        ("device", DEVICE),
        ("batch_size", BATCH_SIZE),
        ("model_type", 'prune_0.20'),
        ("model_path", f'{MODEL_PATH}'),
        ("image_filepath", None),
    ])
    opt = OptionModel._make(arch_hyperparams.values())

    res = evaluate_model_wrapper(
        model=model_2, opt=opt, img_dataset=image_dataset,
        model_name=f'model-{opt.n_hf}-{opt.n_hl}',
        model_weight_path=None, logging=None, tqdm=None, verbose=0,
    )
    eval_info_list.extend(res)

In [54]:
# Single-case experiment: print all evaluation records collected so far.
if EVAL_MODEL_SINGLE_CASE:
    eval_field_names = [
        "model_name", "model_type", "mse", "psnr_db",
        "ssim", "eta_seconds", "footprint_byte", "footprint_percent",
    ]
    table_vals = [record._asdict().values() for record in eval_info_list]
    table = tabulate.tabulate(table_vals, headers=eval_field_names)
    print(table)

In [55]:
# Pruning amounts swept by the cells below.
# Floats are used as fractional amounts, integers as absolute amounts.
trials_rate = [
    0.01, 0.02, 0.03, 0.04, 0.05,
    0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
]
trials_absolute = [4, 5, 8, 10, 12, 16, 20, 30, 32, 40, 50]

In [56]:
# Sweep L1 unstructured pruning over the fractional amounts and accumulate
# every returned record in `eval_info_list`.
for a_rate in trials_rate:
    res = compute_pruning_evaluation(
        model=model,
        image_dataset=image_dataset,
        amount=a_rate,
        pruning_method=prune.L1Unstructured,
        number_trials=NUMBER_TRIALS_RATE_PRUNING,
    )
    eval_info_list.extend(res)

# Preview: only the records of the LAST amount in the sweep (`res`), capped
# at five rows.
eval_field_names = [
    "model_name", "model_type", "mse", "psnr_db",
    "ssim", "eta_seconds", "footprint_byte", "footprint_percent",
]
table_vals = [record._asdict().values() for record in res]
table = tabulate.tabulate(table_vals[:5], headers=eval_field_names)
print(table)

model_name    model_type                mse    psnr_db      ssim    eta_seconds    footprint_byte    footprint_percent
------------  -------------------  --------  ---------  --------  -------------  ----------------  -------------------
model.35.9.0  L1Unstructured_0.90  0.391501    10.0933  0.343753       0.23068             100242                  100
model.35.9.1  L1Unstructured_0.90  0.391501    10.0933  0.343753       0.229137            100242                  100
model.35.9.2  L1Unstructured_0.90  0.391501    10.0933  0.343753       0.230276            100242                  100
model.35.9.3  L1Unstructured_0.90  0.391501    10.0933  0.343753       0.243162            100242                  100
model.35.9.4  L1Unstructured_0.90  0.391501    10.0933  0.343753       0.22999             100242                  100


In [57]:
# Sweep L1 unstructured pruning over the absolute (integer) amounts and
# accumulate every returned record in `eval_info_list`.
# Note: `a_rate` holds an absolute amount here despite its name.
for a_rate in trials_absolute:
    res = compute_pruning_evaluation(
        model=model,
        image_dataset=image_dataset,
        amount=a_rate,
        pruning_method=prune.L1Unstructured,
        number_trials=NUMBER_TRIALS_ABS_PRUNING,
    )
    eval_info_list.extend(res)

# Preview: only the records of the LAST amount in the sweep (`res`), capped
# at five rows.
eval_field_names = [
    "model_name", "model_type", "mse", "psnr_db",
    "ssim", "eta_seconds", "footprint_byte", "footprint_percent",
]
table_vals = [record._asdict().values() for record in res]
table = tabulate.tabulate(table_vals[:5], headers=eval_field_names)
print(table)

model_name    model_type                 mse    psnr_db      ssim    eta_seconds    footprint_byte    footprint_percent
------------  -----------------  -----------  ---------  --------  -------------  ----------------  -------------------
model.35.9.0  L1Unstructured_50  0.000228027    42.4409  0.980586       0.229404            100242                  100
model.35.9.1  L1Unstructured_50  0.000228027    42.4409  0.980586       0.231828            100242                  100
model.35.9.2  L1Unstructured_50  0.000228027    42.4409  0.980586       0.229989            100242                  100
model.35.9.3  L1Unstructured_50  0.000228027    42.4409  0.980586       0.233927            100242                  100
model.35.9.4  L1Unstructured_50  0.000228027    42.4409  0.980586       0.228276            100242                  100


In [58]:
# Sweep random unstructured pruning over the fractional amounts and
# accumulate every returned record in `eval_info_list`.
for a_rate in trials_rate:
    res = compute_pruning_evaluation(
        model=model,
        image_dataset=image_dataset,
        amount=a_rate,
        pruning_method=prune.RandomUnstructured,
        number_trials=NUMBER_TRIALS_RATE_PRUNING,
    )
    eval_info_list.extend(res)

# Preview: only the records of the LAST amount in the sweep (`res`), capped
# at five rows.
eval_field_names = [
    "model_name", "model_type", "mse", "psnr_db",
    "ssim", "eta_seconds", "footprint_byte", "footprint_percent",
]
table_vals = [record._asdict().values() for record in res]
table = tabulate.tabulate(table_vals[:5], headers=eval_field_names)
print(table)

model_name    model_type                    mse    psnr_db      ssim    eta_seconds    footprint_byte    footprint_percent
------------  -----------------------  --------  ---------  --------  -------------  ----------------  -------------------
model.35.9.0  RandomUnstructured_0.90  0.402649    9.97134  0.343485       0.228324            100242                  100
model.35.9.1  RandomUnstructured_0.90  0.378515   10.2398   0.34407        0.234582            100242                  100
model.35.9.2  RandomUnstructured_0.90  0.413244    9.85853  0.343173       0.227128            100242                  100
model.35.9.3  RandomUnstructured_0.90  0.339916   10.7069   0.344289       0.231306            100242                  100
model.35.9.4  RandomUnstructured_0.90  0.424401    9.74283  0.342817       0.233438            100242                  100


In [59]:
# Sweep random unstructured pruning over the absolute (integer) amounts and
# accumulate every returned record in `eval_info_list`.
# Note: `a_rate` holds an absolute amount here despite its name.
for a_rate in trials_absolute:
    res = compute_pruning_evaluation(
        model=model,
        image_dataset=image_dataset,
        amount=a_rate,
        pruning_method=prune.RandomUnstructured,
        number_trials=NUMBER_TRIALS_ABS_PRUNING,
    )
    eval_info_list.extend(res)

# Preview: only the records of the LAST amount in the sweep (`res`), capped
# at five rows.
eval_field_names = [
    "model_name", "model_type", "mse", "psnr_db",
    "ssim", "eta_seconds", "footprint_byte", "footprint_percent",
]
table_vals = [record._asdict().values() for record in res]
table = tabulate.tabulate(table_vals[:5], headers=eval_field_names)
print(table)

model_name    model_type                   mse    psnr_db      ssim    eta_seconds    footprint_byte    footprint_percent
------------  ---------------------  ---------  ---------  --------  -------------  ----------------  -------------------
model.35.9.0  RandomUnstructured_50  0.137656     14.6329  0.444161       0.230111            100242                  100
model.35.9.1  RandomUnstructured_50  0.14814      14.3153  0.428472       0.228542            100242                  100
model.35.9.2  RandomUnstructured_50  0.0428245    19.7042  0.631822       0.230844            100242                  100
model.35.9.3  RandomUnstructured_50  0.0399226    20.0089  0.65495        0.233088            100242                  100
model.35.9.4  RandomUnstructured_50  0.0399359    20.0076  0.639358       0.227269            100242                  100


In [60]:
# Turn every evaluation record into a mapping and load them all into one
# DataFrame (one row per record).
data = [record._asdict() for record in eval_info_list]
df = pd.DataFrame(data)

In [61]:
# First five rows of the results table.
df.head(n=5)

Unnamed: 0,model_name,model_type,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent
0,model-35-9,Basic,0.000226,42.481681,0.980631,0.277036,52657,100.0
1,model.35.9.0,L1Unstructured_0.01,0.000235,42.314139,0.980416,0.245034,100242,100.0
2,model.35.9.1,L1Unstructured_0.01,0.000235,42.314139,0.980416,0.230156,100242,100.0
3,model.35.9.2,L1Unstructured_0.01,0.000235,42.314139,0.980416,0.231386,100242,100.0
4,model.35.9.3,L1Unstructured_0.01,0.000235,42.314139,0.980416,0.23338,100242,100.0


In [62]:
# Distinct `model_type` labels present in the results.
set(df["model_type"].tolist())

{'Basic',
 'L1Unstructured_0.01',
 'L1Unstructured_0.02',
 'L1Unstructured_0.03',
 'L1Unstructured_0.04',
 'L1Unstructured_0.05',
 'L1Unstructured_0.10',
 'L1Unstructured_0.20',
 'L1Unstructured_0.30',
 'L1Unstructured_0.40',
 'L1Unstructured_0.50',
 'L1Unstructured_0.60',
 'L1Unstructured_0.70',
 'L1Unstructured_0.80',
 'L1Unstructured_0.90',
 'L1Unstructured_10',
 'L1Unstructured_12',
 'L1Unstructured_16',
 'L1Unstructured_20',
 'L1Unstructured_30',
 'L1Unstructured_32',
 'L1Unstructured_4',
 'L1Unstructured_40',
 'L1Unstructured_5',
 'L1Unstructured_50',
 'L1Unstructured_8',
 'RandomUnstructured_0.01',
 'RandomUnstructured_0.02',
 'RandomUnstructured_0.03',
 'RandomUnstructured_0.04',
 'RandomUnstructured_0.05',
 'RandomUnstructured_0.10',
 'RandomUnstructured_0.20',
 'RandomUnstructured_0.30',
 'RandomUnstructured_0.40',
 'RandomUnstructured_0.50',
 'RandomUnstructured_0.60',
 'RandomUnstructured_0.70',
 'RandomUnstructured_0.80',
 'RandomUnstructured_0.90',
 'RandomUnstructured_10

In [63]:
# Average every numeric metric per `model_type` and show the five groups with
# the highest mean PSNR (ascending sort, so the best group is the last row).
# `numeric_only=True` is required on pandas >= 2.0, where averaging the string
# column `model_name` raises a TypeError; older versions silently dropped it.
df.groupby(by=["model_type"]).mean(numeric_only=True).sort_values(by=["psnr_db"]).tail(5)

Unnamed: 0_level_0,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent
model_type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
L1Unstructured_5,0.000226,42.481721,0.980631,0.233338,100242,100.0
L1Unstructured_12,0.000226,42.481809,0.980631,0.231515,100242,100.0
L1Unstructured_10,0.000226,42.481848,0.980631,0.233237,100242,100.0
L1Unstructured_16,0.000226,42.483049,0.980632,0.232032,100242,100.0
L1Unstructured_20,0.000226,42.483468,0.980632,0.228877,100242,100.0


In [64]:
def model_type_to_quant_tech(model_type):
    """Return the part of `model_type` that precedes its first underscore.

    A label without any underscore (e.g. 'Basic') is returned unchanged.
    """
    technique, _sep, _rest = model_type.partition("_")
    return technique
# Add (in place) a `quant_tech` column derived from each row's `model_type`,
# then show the last five rows.
df['quant_tech'] = [model_type_to_quant_tech(label) for label in df['model_type'].values]
df.tail(n=5)

Unnamed: 0,model_name,model_type,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent,quant_tech
246,model.35.9.0,RandomUnstructured_50,0.137656,14.6329,0.444161,0.230111,100242,100.0,RandomUnstructured
247,model.35.9.1,RandomUnstructured_50,0.14814,14.315317,0.428472,0.228542,100242,100.0,RandomUnstructured
248,model.35.9.2,RandomUnstructured_50,0.042824,19.704189,0.631822,0.230844,100242,100.0,RandomUnstructured
249,model.35.9.3,RandomUnstructured_50,0.039923,20.008895,0.65495,0.233088,100242,100.0,RandomUnstructured
250,model.35.9.4,RandomUnstructured_50,0.039936,20.007629,0.639358,0.227269,100242,100.0,RandomUnstructured


In [65]:
def model_size_to_bpp(model_footprint, w = 256, h = 256):
    """Scale a model footprint to a per-pixel figure for a `w` x `h` image.

    Computes ``model_footprint * 4 / (w * h)``.

    NOTE(review): the callers pass the `footprint_byte` column; if that is a
    size in bytes, bits-per-pixel would normally use a factor of 8 rather
    than 4 -- confirm the intended unit before relying on absolute values.
    """
    n_pixels = w * h
    scaled_footprint = model_footprint * 4
    return scaled_footprint / n_pixels
# Add (in place) a `bpp` column computed from each row's `footprint_byte`
# with the default 256 x 256 image size, then show the first five rows.
df['bpp'] = [model_size_to_bpp(footprint) for footprint in df['footprint_byte'].values]
df.head(n=5)

Unnamed: 0,model_name,model_type,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent,quant_tech,bpp
0,model-35-9,Basic,0.000226,42.481681,0.980631,0.277036,52657,100.0,Basic,3.213928
1,model.35.9.0,L1Unstructured_0.01,0.000235,42.314139,0.980416,0.245034,100242,100.0,L1Unstructured,6.118286
2,model.35.9.1,L1Unstructured_0.01,0.000235,42.314139,0.980416,0.230156,100242,100.0,L1Unstructured,6.118286
3,model.35.9.2,L1Unstructured_0.01,0.000235,42.314139,0.980416,0.231386,100242,100.0,L1Unstructured,6.118286
4,model.35.9.3,L1Unstructured_0.01,0.000235,42.314139,0.980416,0.23338,100242,100.0,L1Unstructured,6.118286


In [66]:
# `model_size_to_bpp` and the `bpp` column are both produced by the previous
# cell. Redefining the function and recomputing the identical column here was
# redundant (and a duplicate definition silently shadows the first one), so
# this cell only shows the last five rows.
df.tail(5)

Unnamed: 0,model_name,model_type,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent,quant_tech,bpp
246,model.35.9.0,RandomUnstructured_50,0.137656,14.6329,0.444161,0.230111,100242,100.0,RandomUnstructured,6.118286
247,model.35.9.1,RandomUnstructured_50,0.14814,14.315317,0.428472,0.228542,100242,100.0,RandomUnstructured,6.118286
248,model.35.9.2,RandomUnstructured_50,0.042824,19.704189,0.631822,0.230844,100242,100.0,RandomUnstructured,6.118286
249,model.35.9.3,RandomUnstructured_50,0.039923,20.008895,0.65495,0.233088,100242,100.0,RandomUnstructured,6.118286
250,model.35.9.4,RandomUnstructured_50,0.039936,20.007629,0.639358,0.227269,100242,100.0,RandomUnstructured,6.118286


In [67]:
# Scatter of `psnr_db` against `bpp`, coloured by `quant_tech`; rows labelled
# 'Basic' are excluded from the plot.
hue, x, y = 'quant_tech', 'bpp', 'psnr_db'
fig = px.scatter(
    df.loc[df[hue] != 'Basic'],
    x=x,
    y=y,
    color=hue,
    template=DASH_TEMPLATES_LIST[2],
)
# The layout update returns the figure, so it is displayed as the cell output.
fig.update_layout(
    template=DASH_TEMPLATES_LIST[2],
    title_text=f'{y.upper()} | Groupped by {hue} | dataframes',
)

In [68]:
# Box plot of `psnr_db`, one box per `quant_tech`; rows labelled 'Basic' are
# excluded. `x` is assigned for symmetry with the scatter cell but not used.
hue, x, y = 'quant_tech', 'bpp', 'psnr_db'
fig = px.box(
    df.loc[df[hue] != 'Basic'],
    y=y,
    color=hue,
    template=DASH_TEMPLATES_LIST[2],
)
# The layout update returns the figure, so it is displayed as the cell output.
fig.update_layout(
    template=DASH_TEMPLATES_LIST[2],
    title_text=f'{y.upper()} | Groupped by {hue} | dataframes',
)

In [69]:
# Violin plot of `psnr_db`, one violin per `quant_tech`; rows labelled 'Basic'
# are excluded. `x` is assigned for symmetry with the scatter cell but not used.
hue, x, y = 'quant_tech', 'bpp', 'psnr_db'
fig = px.violin(
    df.loc[df[hue] != 'Basic'],
    y=y,
    color=hue,
    template=DASH_TEMPLATES_LIST[2],
)
# The layout update returns the figure, so it is displayed as the cell output.
fig.update_layout(
    template=DASH_TEMPLATES_LIST[2],
    title_text=f'{y.upper()} | Groupped by {hue} | dataframes',
)

In [70]:
def model_type_to_quant_tech_2(model_type):
    """Map a `model_type` label to '<technique>_abs' or '<technique>_rate'.

    The label is expected to look like '<technique>_<amount>'. The special
    label 'Basic' is returned unchanged.

    The amount is classified by how it is written, mirroring torch's pruning
    convention (an int amount is an absolute count, a float amount is a
    fraction): an integer literal such as '50' gives 'abs', any other numeric
    literal such as '0.90' gives 'rate'. The previous test
    ``int(float(amount)) != 0`` mislabelled a rate of '1.00' as 'abs' and an
    absolute amount of '0' as 'rate'.

    Raises:
        IndexError: if the label contains no underscore-separated amount.
        ValueError: if the amount is not a numeric literal.
    """
    if model_type == 'Basic':
        return model_type

    tokens = model_type.split("_")
    technique, amount = tokens[0], tokens[1]

    try:
        int(amount)
    except ValueError:
        # Not an integer literal: must still be numeric, otherwise this raises
        # ValueError exactly as the original conversion did.
        float(amount)
        return technique + "_" + "rate"
    return technique + "_" + "abs"
# Add (in place) a `quant_tech_2` column derived from each row's `model_type`,
# then list the distinct labels it contains (set order is arbitrary).
df['quant_tech_2'] = [model_type_to_quant_tech_2(label) for label in df['model_type'].values]
list(set(df["quant_tech_2"].values))

['RandomUnstructured_rate',
 'L1Unstructured_abs',
 'Basic',
 'L1Unstructured_rate',
 'RandomUnstructured_abs']

In [71]:
# Scatter of `psnr_db` against `bpp`, coloured by `quant_tech_2`; rows labelled
# 'Basic' are excluded from the plot.
hue, x, y = 'quant_tech_2', 'bpp', 'psnr_db'
fig = px.scatter(
    df.loc[df[hue] != 'Basic'],
    x=x,
    y=y,
    color=hue,
    template=DASH_TEMPLATES_LIST[2],
)
# The layout update returns the figure, so it is displayed as the cell output.
fig.update_layout(
    template=DASH_TEMPLATES_LIST[2],
    title_text=f'{y.upper()} | Groupped by {hue} | dataframes',
)

In [72]:
# Box plot of `psnr_db`, one box per `quant_tech_2`; rows labelled 'Basic' are
# excluded. `x` is assigned for symmetry with the scatter cell but not used.
hue, x, y = 'quant_tech_2', 'bpp', 'psnr_db'
fig = px.box(
    df.loc[df[hue] != 'Basic'],
    y=y,
    color=hue,
    template=DASH_TEMPLATES_LIST[2],
)
# The layout update returns the figure, so it is displayed as the cell output.
fig.update_layout(
    template=DASH_TEMPLATES_LIST[2],
    title_text=f'{y.upper()} | Groupped by {hue} | dataframes',
)

In [73]:
# Violin plot of `psnr_db`, one violin per `quant_tech_2`; rows labelled
# 'Basic' are excluded. `x` is assigned for symmetry with the scatter cell but
# not used.
hue, x, y = 'quant_tech_2', 'bpp', 'psnr_db'
fig = px.violin(
    df.loc[df[hue] != 'Basic'],
    y=y,
    color=hue,
    template=DASH_TEMPLATES_LIST[2],
)
# The layout update returns the figure, so it is displayed as the cell output.
fig.update_layout(
    template=DASH_TEMPLATES_LIST[2],
    title_text=f'{y.upper()} | Groupped by {hue} | dataframes',
)