### Setup Notebook

In [None]:
#@title Enable prune mode
#@markdown ---
# Master switch: the project-fetch, data-extraction, help and pruning cells below run only when True.
PRUNE_MODEL = False                   #@param {type:"boolean"}
# With PRUNE_MODEL: run main.py on the single state dict at MODEL_PATH.
PRUNE_A_FILE = False                   #@param {type:"boolean"}
# With PRUNE_MODEL: run main.py on the CSV file(s) referenced by DATA_CSV_PATH.
PRUNE_CSV_FILES = False                #@param {type:"boolean"}

# NOTE(review): not read by any cell visible here; presumably toggles plotting further on — confirm.
SHOW_PLOTS = False                #@param {type:"boolean"}

In [None]:
#@title Github related infos
#@markdown ---
# Local checkout location of the project; the project-fetch cell deletes and re-clones it.
PROJECT_NAME_PATH = '/content/My-Siren-Deep-Learning-Test' #@param {type:"string"}


# Repository to clone and the branch checked out right after cloning.
GITHUB_PROJECT_URL = 'https://github.com/franec94/My-Siren-Deep-Learning-Test.git' #@param {type:"string"}
BRANCH_NAME = 'cmd-line-tools' #@param {type:"string"}
# Tool directory (relative to <PROJECT_NAME_PATH>/dev-cmd-line-tools/) whose main.py is run below.
CMD_TOOL_NAME = 'post-train-cmd-line-tools/prune-eval-tool' #@param {type:"string"}

In [None]:
#@title Model's Hyper-Params
#@markdown ---
# Forwarded to main.py as --n_hf / --n_hl (presumably hidden features per layer and
# number of hidden layers of the trained network — confirm in main.py).
N_HF=35 #@param {type:"integer"}
N_HL=9  #@param {type:"integer"}
# Forwarded to main.py as --sidelength.
SIDELENGTH=256 #@param {type:"integer"}
# NOTE(review): DEVICE and BATCH_SIZE are not passed to main.py by any command visible here,
# and "gpu" is not a valid torch device string ("cuda" is) — confirm how these are consumed.
DEVICE = "cpu" #@param ["cpu", "cuda", "gpu"]
BATCH_SIZE=1 #@param {type:"integer"}
# State dict pruned by the PRUNE_A_FILE cell (passed as --models_filepath).
MODEL_PATH='/content/model_final.pth' #@param {type:"string"}

In [None]:
# NOTE(review): FROM_UPLOADED_FILE / UPLOADED_FILE_NAME are not read by any cell visible here — confirm where they are used.
FROM_UPLOADED_FILE = False #@param {type:"boolean"}
UPLOADED_FILE_NAME = "/content/final_result_pruning.csv" #@param {type:"string"}

In [None]:
#@markdown ---
#@markdown ##### Data csv Info:
# Run id: the data archive is expected at /content/<TIMESTAMP_VAL>.zip and, once extracted
# into /content, its CSV at /content/<TIMESTAMP_VAL>/colab_<TIMESTAMP_VAL>.csv.
TIMESTAMP_VAL = "1605077314-991078" #@param {type:"string"}
# With PRUNE_MODEL: extract only the archive selected by TIMESTAMP_VAL.
EVAL_CSV_FILES = False  #@param {type:"boolean"}
# With PRUNE_MODEL: extract every /content/*.zip and point DATA_CSV_PATH at all of their CSVs.
EVAL_ALL_CSV_FILES = False  #@param {type:"boolean"}

In [None]:
# All artefacts of one run live under /content and are keyed by the run's timestamp:
#   /content/<ts>.zip  --extract-->  /content/<ts>/  containing  colab_<ts>.csv
DATA_PATH = "/content/{}".format(TIMESTAMP_VAL)
DATA_ZIP_PATH = "{}.zip".format(DATA_PATH)
DATA_CSV_PATH = "{}/colab_{}.csv".format(DATA_PATH, TIMESTAMP_VAL)

DATA_ZIP_PATH, DATA_PATH, DATA_CSV_PATH

('/content/1605077314-991078.zip',
 '/content/1605077314-991078',
 '/content/1605077314-991078/colab_1605077314-991078.csv')

In [None]:
# Extract the archive selected by TIMESTAMP_VAL into /content (yielding DATA_PATH / DATA_CSV_PATH).
if PRUNE_MODEL and EVAL_CSV_FILES:
    import os
    import zipfile

    print("Extracting data...")
    if os.path.exists(DATA_ZIP_PATH):
        print(f"{DATA_ZIP_PATH} exists!")
        # extractall overwrites files that are already present, so re-running the cell never
        # stalls on an interactive "replace?" prompt, and it needs no option-ordering tricks
        # (unzip only honours options placed *before* the archive name).
        with zipfile.ZipFile(DATA_ZIP_PATH) as archive:
            archive.extractall("/content/")
    else:
        print(f"{DATA_ZIP_PATH} does not exist!")

In [None]:
# Extract every /content/*.zip and set DATA_CSV_PATH to a space-separated list of the CSVs
# they contain (expanded into several --csv_files arguments by the PRUNE_CSV_FILES cell).
if PRUNE_MODEL and EVAL_ALL_CSV_FILES:
    # Imported here because this cell can run before the imports cell and without the
    # single-archive cell having imported `os` for us.
    import pathlib
    import zipfile

    csv_paths = []
    # sorted() makes the order of DATA_CSV_PATH reproducible across runs.
    for a_zip_file in sorted(pathlib.Path('/content').glob('*.zip')):
        print("Extracting data...")
        if not a_zip_file.is_file():
            print(f"{a_zip_file} does not exist!")
            continue
        print(f"{a_zip_file} exists!")
        run_id = a_zip_file.stem
        csv_paths.append(f"/content/{run_id}/colab_{run_id}.csv")
        # extractall overwrites existing files, so re-running never blocks on a prompt.
        with zipfile.ZipFile(a_zip_file) as archive:
            archive.extractall('/content/')
    DATA_CSV_PATH = ' '.join(csv_paths)
    print(DATA_CSV_PATH)

In [None]:
#@title Save results
#@markdown ---
# Forwarded to main.py as --logging_root / --experiment_name by the pruning cells.
LOGGING_ROOT = '/content/results/cameramen' #@param {type:"string"}
EXPERIMENT_NAME = 'train' #@param {type:"string"}


In [None]:
#@title Handle workspace
#@markdown ---
# When True, the next cell deletes RESULTS_DIR_PATH together with everything inside it.
CLEAR_RESULTS_DIR = False      #@param {type:"boolean"}
# NOTE(review): the default is the parent of LOGGING_ROOT's default, so clearing it presumably
# removes the logs of every previous experiment — confirm that is intended.
RESULTS_DIR_PATH = "/content/results"      #@param {type:"string"}

In [None]:
# Remove RESULTS_DIR_PATH entirely so the next run starts from an empty results tree.
if CLEAR_RESULTS_DIR:
    import os
    import shutil

    # isdir() is already False for paths that do not exist.
    if os.path.isdir(RESULTS_DIR_PATH):
        print(f"Clearing {RESULTS_DIR_PATH}...")
        # rmtree also deletes dot-files, which a shell `rm -R dir/*` glob skips (leaving the
        # directory non-empty so a following `rmdir` fails).
        shutil.rmtree(RESULTS_DIR_PATH)

### Imports

In [None]:
# Installing third party dependencies
# Install each package of dependencies_list that the kernel's Python does not already have.
import re
import subprocess
import sys

print("Installing required libraries...")

dependencies_list = "cmapy,sk-video,pytorch-model-summary,ConfigArgParse,tabulate,chart_studio,dash,dash_bootstrap_components".split(",")


def _canonical_dist_name(name):
    """Normalise a distribution name per PEP 503: lower-case, runs of '-', '_', '.' become '-'.

    pip freeze reports e.g. `chart-studio`, while the list above says `chart_studio`;
    both must compare equal.
    """
    return re.sub(r"[-_.]+", "-", name).lower()


def _installed_dist_names():
    """Return the canonical names of all distributions listed by `pip freeze` for this kernel."""
    frozen = subprocess.check_output(
        [sys.executable, "-m", "pip", "freeze"], universal_newlines=True
    )
    names = set()
    for line in frozen.splitlines():
        line = line.strip()
        if not line or line.startswith(("#", "-e")):
            continue
        # Lines look like `pkg==1.0` or `pkg @ file:///...`: keep only the project name.
        names.add(_canonical_dist_name(re.split(r"[\s=<>!~@;\[]", line, maxsplit=1)[0]))
    return names


installed_dist_names = _installed_dist_names()
for a_req in dependencies_list:
    # Exact (normalised) match: a prefix test would treat `dash` as installed whenever
    # only `dash-bootstrap-components` is present.
    if _canonical_dist_name(a_req) in installed_dist_names:
        print(f"{a_req} already installed!")
        continue
    # Install with the kernel's own interpreter (same environment `%pip` would target);
    # failures are reported but do not abort the remaining installs.
    return_code = subprocess.call([sys.executable, "-m", "pip", "install", a_req, "-q"])
    if return_code != 0:
        print(f"pip install {a_req} failed (exit code {return_code})")

Installing required libraries...
cmapy already installed!
sk-video already installed!
pytorch-model-summary already installed!
ConfigArgParse already installed!
tabulate already installed!
dash already installed!


In [None]:
from __future__ import print_function
from __future__ import division

# All imports for the notebook live in this one cell, and each module is imported once.
# Optional heavy libraries (torch, torchvision, skimage) are guarded so the notebook still
# loads without them, but only a genuine ImportError is tolerated.

# Names of plotly's built-in figure templates (not used in the cells visible here).
DASH_TEMPLATES_LIST = ["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none"]

# --------------------------------------------- #
# Standard Library
# --------------------------------------------- #
import collections
import copy
import datetime
import functools
import itertools
import logging
import math
import operator
import os
import pickle
import random
import re
import shutil
import sys
import time
from collections import OrderedDict
from functools import partial
from pprint import pprint
from typing import Tuple, Union

# --------------------------------------------- #
# Third Party utilities
# --------------------------------------------- #
import configargparse
import h5py
import tabulate
from PIL import Image
from tqdm import tqdm

# --------------------------------------------- #
# Data Science and Machine Learning Libraries
# --------------------------------------------- #
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import train_test_split

matplotlib.style.use('ggplot')

# --------------------------------------------- #
# Torch
# --------------------------------------------- #
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import DataLoader, Dataset
    import torch.quantization
    import torch.nn.utils.prune as prune
except ImportError as err:
    print(f"torch not available! ({err})")

# --------------------------------------------- #
# Import: TorchVision
# --------------------------------------------- #
try:
    import torchvision
    from torchvision import datasets
    from torchvision import transforms
    from torchvision.transforms import Resize, Compose, ToTensor, CenterCrop, Normalize
    from torchvision.utils import save_image
except ImportError as err:
    print(f"torchvision library not available! ({err})")

# --------------------------------------------- #
# Plotly
# --------------------------------------------- #
import chart_studio.plotly as py
import plotly.figure_factory as ff
import plotly.express as px

# --------------------------------------------- #
# Import: skimage
# --------------------------------------------- #
try:
    import skimage
    import skimage.metrics as skmetrics
    from skimage.metrics import peak_signal_noise_ratio as psnr
    from skimage.metrics import structural_similarity as ssim
    from skimage.metrics import mean_squared_error
except ImportError as err:
    print(f"skimage library not available! ({err})")

### Fetch GitHub project

In [None]:
import os

# Reset the kernel's working directory to the Colab workspace root; earlier runs of the
# project-fetch cell leave it inside the project's tool directory.
WORKSPACE_ROOT = '/content'
os.chdir(WORKSPACE_ROOT)

In [None]:
# Setup wd to remove trash
# Fresh-clone the project into PROJECT_NAME_PATH, check out BRANCH_NAME and cd into the
# command-line tool directory whose main.py the following cells run.
if PRUNE_MODEL:
    import os
    import shutil
    import subprocess

    def _run_and_echo(cmd, cwd=None):
        """Run argv list `cmd`, print its merged stdout/stderr, raise CalledProcessError on failure."""
        completed = subprocess.run(
            cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            universal_newlines=True,
        )
        print(completed.stdout, end="")
        completed.check_returncode()

    project_root = os.path.abspath(PROJECT_NAME_PATH)
    parent_dir = os.path.dirname(project_root)
    os.makedirs(parent_dir, exist_ok=True)
    # Step out of the project tree before deleting it: on a re-run the cwd is still the tool
    # directory inside it, and git cannot clone from a working directory that no longer exists.
    os.chdir(parent_dir)

    # Remove trash
    if os.path.isdir(project_root):
        print("Cleaning from old project...")
        shutil.rmtree(project_root)
    else:
        print("No project found.")

    logs_base_dir = os.path.join("/content/outputs", "summaries")
    os.makedirs(logs_base_dir, exist_ok=True)

    # Clone into PROJECT_NAME_PATH explicitly rather than relying on git deriving the target
    # directory from the URL and the current directory. The project was just removed, so
    # there is never an existing checkout to update here.
    _run_and_echo(["git", "clone", GITHUB_PROJECT_URL, project_root])
    _run_and_echo(["git", "checkout", BRANCH_NAME], cwd=project_root)
    full_path_cmd = os.path.join(project_root, f'dev-cmd-line-tools/{CMD_TOOL_NAME}')
    os.chdir(full_path_cmd)
else:
    print("No github project cloned and no branch activated and switched to!")

No github project cloned and no branch activated and switched to!


### Run program

In [None]:
# Show help
# Print main.py's command-line options. `main.py` is resolved against the current working
# directory, which the project-fetch cell sets (os.chdir) to the tool directory.
if PRUNE_MODEL:
    !python main.py --help

In [None]:
# Run main.py on the single state dict at MODEL_PATH. The command passes the model shape
# (SIDELENGTH, N_HF, N_HL), the listed global pruning techniques / rates / absolute amounts
# and the --dynamic_quant dtypes. How main.py combines these lists is not visible here.
# DEVICE and BATCH_SIZE are not forwarded by this command.
# Keep each continuation backslash as the last character on its line (no trailing spaces),
# so IPython joins the lines into a single shell command.
if PRUNE_MODEL and PRUNE_A_FILE:
    !python main.py \
        --logging_root {LOGGING_ROOT} \
        --experiment_name {EXPERIMENT_NAME} \
        --models_filepath {MODEL_PATH} \
        --sidelength {SIDELENGTH} \
        --n_hf {N_HF} \
        --n_hl {N_HL} \
        --global_pruning_techs 'L1Unstructured' 'RandomUnstructured' \
        --global_pruning_rates .01 .02 .03 .04 .05 .06 .07 .08 .1 .2 .3 .4 .5 .6 .7 .8 .9 \
        --global_pruning_abs 10 20 30 40 50 60 70 80 90 100 150 200 500 1000 1500 1700 2000 2500 2700 3000\
        --dynamic_quant qint8 qfloat16 \
        --verbose 0
    pass
else:
    print("No single architecture's state dict file pruned!")
    pass

No single architecture's state dict file pruned!


In [None]:
if PRUNE_MODEL and PRUNE_CSV_FILES:
    """
    !python main.py \ 
        --logging_root {LOGGING_ROOT} \
        --experiment_name {EXPERIMENT_NAME} \
        --csv_files {DATA_CSV_PATH} \
        --sidelength {SIDELENGTH} \
        --global_pruning_techs 'L1Unstructured' 'RandomUnstructured' \
        --global_pruning_rates .01 .02 .03 .04 .05 .06 .07 .08 .1 .2 .3 .4 .5 .6 .7 .8 .9 \
        --global_pruning_abs 10 20 30 40 50 60 70 80 90 100 150 200 500 1000 1500 1700 2000 2500 2700 3000 4000 5000 \      
        --dynamic_quant qint8 qfloat16 \
        --verbose 0
    """
    pass
else:
    print("No many architectures state dict file pruned!")
    pass
# --global_abs_linspace 50 \

No many architectures state dict file pruned!
