### Setup Notebook

In [1]:
#@title Enable prune mode
#@markdown ---
# Master switch: the setup (clone) cell, the help cell and the run cell below
# only execute when this is True. The results-analysis cells run regardless
# and read a CSV from disk.
PRUNE_MODEL = False                   #@param {type:"boolean"}

In [2]:
#@title Github related infos
#@markdown ---
# Local location of the project checkout. The setup cell clones the repository
# and then changes into this path, so it must point at the clone's location.
PROJECT_NAME_PATH = '/content/My-Siren-Deep-Learning-Test' #@param {type:"string"}

    
# Repository to clone and branch to check out. CMD_TOOL_NAME is the tool folder,
# relative to <PROJECT_NAME_PATH>/dev-cmd-line-tools/, that the setup cell
# changes into before main.py is run.
GITHUB_PROJECT_URL = 'https://github.com/franec94/My-Siren-Deep-Learning-Test.git' #@param {type:"string"}
BRANCH_NAME = 'cmd-line-tools' #@param {type:"string"}
CMD_TOOL_NAME = 'post-train-cmd-line-tools/prune-eval-tool' #@param {type:"string"}

In [3]:
#@title Model's Hyper-Params
#@markdown ---
# N_HF / N_HL / SIDELENGTH are forwarded to main.py as --n_hf / --n_hl /
# --sidelength. Presumably they are the hidden-feature count, hidden-layer count
# and image side length, and must match the checkpoint at MODEL_PATH (confirm).
N_HF=35 #@param {type:"integer"}
N_HL=9  #@param {type:"integer"}
SIDELENGTH=256 #@param {type:"integer"}
# Only offer device types torch accepts: "gpu" is not a valid torch device string.
# NOTE(review): DEVICE and BATCH_SIZE are not passed to main.py by the run cell.
DEVICE = "cpu" #@param ["cpu", "cuda"]
BATCH_SIZE=1 #@param {type:"integer"}
MODEL_PATH='/content/model_final.pth' #@param {type:"string"}

In [4]:
#@title Save results
#@markdown ---
# Passed to main.py as --logging_root / --experiment_name by the run cell.
# NOTE(review): "cameramen" may be a typo for "cameraman"; the results-analysis
# cell reads from this same folder, so keep the two in sync if renamed.
LOGGING_ROOT = '/content/results/cameramen' #@param {type:"string"}
EXPERIMENT_NAME = 'train' #@param {type:"string"}


### Imports

In [5]:
from __future__ import print_function
from __future__ import division

# --------------------------------------------- #
# Standard Library, plus some Third Party Libraries
# --------------------------------------------- #

# Plotly template names; the plotting cells pick index 2 ("plotly_dark").
DASH_TEMPLATES_LIST = [
    "plotly",
    "plotly_white",
    "plotly_dark",
    "ggplot2",
    "seaborn",
    "simple_white",
    "none",
]

from PIL import Image
from functools import partial
from pprint import pprint
from tqdm import tqdm
from typing import Tuple, Union


import configargparse
import copy
import collections
import datetime
import functools
import h5py
import logging
import math
import os
import operator
import pickle
import random
import shutil
import sys
import re
import tabulate 
import time
# import visdom


from collections import OrderedDict
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# --------------------------------------------- #
# Data Science and Machine Learning Libraries
# --------------------------------------------- #
import matplotlib
import matplotlib.pyplot as plt
matplotlib.style.use('ggplot')

import numpy as np
import pandas as pd
import sklearn

from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import train_test_split

# --------------------------------------------- #
# Torch
# --------------------------------------------- #
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import DataLoader, Dataset
    import torch.quantization
    import torch.nn.utils.prune as prune
except:
    print("torch not available!")
    pass


# --------------------------------------------- #
# Import: TorchVision
# --------------------------------------------- #
try:
    import torchvision
    from torchvision import datasets
    from torchvision import transforms
    from torchvision.transforms import Resize, Compose, ToTensor, CenterCrop, Normalize
    from torchvision.utils import save_image
except:
    print("torchvision library not available!")
    pass

# Plotly imports.
# ----------------------------------------------- #
import chart_studio.plotly as py
import plotly.figure_factory as ff
import plotly.express as px

# --------------------------------------------- #
# Import: skimage
# --------------------------------------------- #
try:
    import skimage
    import skimage.metrics as skmetrics
    from skimage.metrics import peak_signal_noise_ratio as psnr
    from skimage.metrics import structural_similarity as ssim
    from skimage.metrics import mean_squared_error
except:
    print("skimage library not available!")
    pass

### Fetch GitHub project

In [6]:
# Installing third party dependencies
print("Installing required libraries...")

old_requirements = '/content/tmp_requirements.txt'
!pip freeze > {old_requirements}
dependencies_list = "cmapy,sk-video,pytorch-model-summary,ConfigArgParse,tabulate,chart_studio,dash,dash_bootstrap_components".split(",")

with open(old_requirements) as f:
    old_requirements_list = f.read().split("\n")
    for a_req in dependencies_list:
        found_req = False
        for old_req in old_requirements_list:
            if old_req.startswith(a_req):
                print(f"{a_req} already installed!")
                found_req = True
                break
        if found_req is False:
            !pip install {a_req} -q
    pass
!rm -f {old_requirements}

Installing required libraries...
cmapy already installed!
sk-video already installed!
pytorch-model-summary already installed!
ConfigArgParse already installed!
tabulate already installed!
dash already installed!


In [7]:
import os

# Work from the Colab workspace root: later cells (e.g. `git clone`) resolve
# relative paths against the current directory.
os.chdir(path='/content')

In [8]:
# Setup wd to remove trash

if PRUNE_MODEL:
    # Remove trash
    import os
    if os.path.exists(f'{PROJECT_NAME_PATH}') and os.path.isdir(f'{PROJECT_NAME_PATH}'):
        print("Cleaning from old project...")
        !rm -r {PROJECT_NAME_PATH}
    else:
        print("No project found.")
        pass

    import os
    logs_base_dir = os.path.join("/content/outputs", "summaries")
    os.makedirs(logs_base_dir, exist_ok=True)

    if os.path.exists(PROJECT_NAME_PATH) is False:
        !git clone {GITHUB_PROJECT_URL}
        os.chdir(PROJECT_NAME_PATH)
        !git checkout {BRANCH_NAME}
        full_path_cmd = os.path.join(PROJECT_NAME_PATH, f'dev-cmd-line-tools/{CMD_TOOL_NAME}')
        os.chdir(full_path_cmd)
    else:
        os.chdir(PROJECT_NAME_PATH)
        !git checkout {BRANCH_NAME}
        !git fetch
        full_path_cmd = os.path.join(PROJECT_NAME_PATH, f'dev-cmd-line-tools/{CMD_TOOL_NAME}')
        os.chdir(full_path_cmd)
        pass
    pass

### Run program

In [9]:
# Show help
# Prints the CLI usage of the tool's main.py, resolved relative to the cwd that
# the setup cell moves into (only when PRUNE_MODEL is enabled).
if PRUNE_MODEL:
    !python main.py --help

In [10]:
# Run the tool's main.py (in the cwd set by the setup cell) on the checkpoint
# at MODEL_PATH. It passes the architecture settings (SIDELENGTH, N_HF, N_HL),
# the global pruning techniques, the rates and absolute amounts, and the
# dynamic-quantization dtypes to evaluate.
# NOTE(review): the rates jump from .08 to .1 (no .09) — confirm intended.
# NOTE(review): DEVICE and BATCH_SIZE from the hyper-params cell are not passed.
if PRUNE_MODEL:
    !python main.py \
        --logging_root {LOGGING_ROOT} \
        --experiment_name {EXPERIMENT_NAME} \
        --models_filepath {MODEL_PATH} \
        --sidelength {SIDELENGTH} \
        --n_hf {N_HF} \
        --n_hl {N_HL} \
        --global_pruning_techs 'L1Unstructured' 'RandomUnstructured' \
        --global_pruning_rates .01 .02 .03 .04 .05 .06 .07 .08 .1 .2 .3 .4 .5 .6 .7 .8 .9 \
        --global_pruning_abs 10 20 30 40 50 60 70 80 90 100 150 200 \
        --dynamic_quant qint8 qfloat16 \
        --verbose 0
    pass

### Check out results

In [11]:
import pandas as pd

In [13]:
# Load the per-model metrics written by one specific earlier run of the tool.
# The run folder (<date>/<timestamp>) identifies that execution; change
# RESULTS_RUN_SUBDIR to analyse another run. The file name appears to be
# "<experiment_name>results.csv" (e.g. 'trainresults.csv' for 'train').
RESULTS_RUN_SUBDIR = os.path.join('15-11-2020', '1605455848-997809')
results_csv_path = os.path.join(LOGGING_ROOT, RESULTS_RUN_SUBDIR, f'{EXPERIMENT_NAME}results.csv')

# 'Unnamed: 0' is the header pandas gives an unnamed first column — presumably
# the index saved by `to_csv`; drop it.
df = pd.read_csv(results_csv_path).drop(columns=['Unnamed: 0'])

# The plotting cells below group and plot on these columns.
assert {'quant_tech', 'quant_tech_2', 'bpp', 'psnr_db'} <= set(df.columns), \
    f"unexpected columns in {results_csv_path}: {list(df.columns)}"

In [20]:
# First five result rows.
df.iloc[:5]

Unnamed: 0,model_name,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent,bpp,quant_tech,quant_tech_2,prune_amount
0,model.35.9.0,0.163425,13.888062,0.367754,0.171121,100339,100.0,6.124207,RandomUnstructured,RandomUnstructured_rate,0.01
1,model.35.9.1,0.201791,12.972258,0.320147,0.157675,100339,100.0,6.124207,RandomUnstructured,RandomUnstructured_rate,0.01
2,model.35.9.2,0.07534,17.250423,0.521151,0.148324,100339,100.0,6.124207,RandomUnstructured,RandomUnstructured_rate,0.01
3,model.35.9.3,0.088674,16.542898,0.488663,0.150604,100339,100.0,6.124207,RandomUnstructured,RandomUnstructured_rate,0.01
4,model.35.9.4,0.168056,13.766441,0.350631,0.151277,100339,100.0,6.124207,RandomUnstructured,RandomUnstructured_rate,0.01


In [21]:
# Last five result rows.
df.iloc[-5:]

Unnamed: 0,model_name,mse,psnr_db,ssim,eta_seconds,footprint_byte,footprint_percent,bpp,quant_tech,quant_tech_2,prune_amount
575,model.35.9.5,0.000271,41.696013,0.979616,0.14742,100339,100.0,6.124207,L1Unstructured,L1Unstructured_abs,200.0
576,model.35.9.6,0.000271,41.696013,0.979616,0.15491,100339,100.0,6.124207,L1Unstructured,L1Unstructured_abs,200.0
577,model.35.9.7,0.000271,41.696013,0.979616,0.148178,100339,100.0,6.124207,L1Unstructured,L1Unstructured_abs,200.0
578,model.35.9.8,0.000271,41.696013,0.979616,0.152714,100339,100.0,6.124207,L1Unstructured,L1Unstructured_abs,200.0
579,model.35.9.9,0.000271,41.696013,0.979616,0.152576,100339,100.0,6.124207,L1Unstructured,L1Unstructured_abs,200.0


In [14]:
# PSNR against bits-per-pixel for each evaluated model, coloured by `quant_tech`;
# rows whose `quant_tech` is 'Basic' are excluded.
hue = 'quant_tech'; x = 'bpp'; y = 'psnr_db'
fig = px.scatter(df[df[hue] != 'Basic'], x=x, y=y, color=hue,
                 template=DASH_TEMPLATES_LIST[2])
fig.update_layout(title_text=f'{y.upper()} | Grouped by {hue} | dataframes')

In [15]:
# PSNR against bits-per-pixel for each evaluated model, coloured by
# `quant_tech_2`; rows whose `quant_tech_2` is 'Basic' are excluded.
hue = 'quant_tech_2'; x = 'bpp'; y = 'psnr_db'
fig = px.scatter(df[df[hue] != 'Basic'], x=x, y=y, color=hue,
                 template=DASH_TEMPLATES_LIST[2])
fig.update_layout(title_text=f'{y.upper()} | Grouped by {hue} | dataframes')

In [16]:
# Distribution of PSNR per `quant_tech` value, excluding rows tagged 'Basic'.
# (`x` is set for parity with the scatter cells but the box plot does not use it.)
hue = 'quant_tech'; x = 'bpp'; y = 'psnr_db'
fig = px.box(df[df[hue] != 'Basic'], y=y, color=hue,
             template=DASH_TEMPLATES_LIST[2])
fig.update_layout(title_text=f'{y.upper()} | Grouped by {hue} | dataframes')

In [17]:
# Distribution of PSNR per `quant_tech_2` value, excluding rows tagged 'Basic'.
# (`x` is set for parity with the scatter cells but the box plot does not use it.)
hue = 'quant_tech_2'; x = 'bpp'; y = 'psnr_db'
fig = px.box(df[df[hue] != 'Basic'], y=y, color=hue,
             template=DASH_TEMPLATES_LIST[2])
fig.update_layout(title_text=f'{y.upper()} | Grouped by {hue} | dataframes')

In [18]:
# Violin view of PSNR per `quant_tech` value, excluding rows tagged 'Basic'.
# (`x` is set for parity with the scatter cells but the violin plot does not use it.)
hue = 'quant_tech'; x = 'bpp'; y = 'psnr_db'
fig = px.violin(df[df[hue] != 'Basic'], y=y, color=hue,
                template=DASH_TEMPLATES_LIST[2])
fig.update_layout(title_text=f'{y.upper()} | Grouped by {hue} | dataframes')

In [19]:
# Violin view of PSNR per `quant_tech_2` value, excluding rows tagged 'Basic'.
# (`x` is set for parity with the scatter cells but the violin plot does not use it.)
hue = 'quant_tech_2'; x = 'bpp'; y = 'psnr_db'
fig = px.violin(df[df[hue] != 'Basic'], y=y, color=hue,
                template=DASH_TEMPLATES_LIST[2])
fig.update_layout(title_text=f'{y.upper()} | Grouped by {hue} | dataframes')