# Siren Case Study: Training-Test via Distiller Framework
---
### Training a SIREN model, optionally applying pruning and quantization

## Options

In [1]:
#@title Github related infos:
#@markdown ---
# Local working copy of the project; the setup cell removes any stale copy at this
# path and `cd`s into it after cloning.
PROJECT_NAME_PATH = '/content/distiller' #@param {type:"string"}


# Repository to clone and the branch checked out after cloning.
GITHUB_PROJECT_URL = 'https://github.com/franec94/distiller.git' #@param {type:"string"}
BRANCH_NAME = 'siren-support' #@param {type:"string"}
# Optional sub-directory of the project to `cd` into after setup; empty means the project root.
CMD_TOOL_NAME = '' #@param {type:"string"}

In [2]:
#@title Install Dependencies -> Options:
#@markdown ---
# Delete any existing checkout and clone the project afresh (see "Clone Project").
CLONE_GITHUB_PROJECT = True #@param {type:"boolean"}
# Install the extra third-party packages listed in the "Libs" cell, skipping ones already present.
INSTALL_SOME_LIBS = False #@param {type:"boolean"}
# Install the project's requirements.txt with pip.
INSTALL_DISTILLER_PIP_REQ = False #@param {type:"boolean"}

In [3]:
#@title Running -> Options:
#@markdown ---
# Gate the in-notebook ("Colab mode") imports, model definitions and training loop.
RUN_COLAB_CODE = False #@param {type:"boolean"}
# Run the plain `main.py` script ("Base mode").
RUN_MAIN_SIREN_BASE = False #@param {type:"boolean"}
# Run `siren_main_app.py` ("App Mode"); combine with a pruning flag or PLAIN_TRAINING.
RUN_MAIN_SIREN_APP = True #@param {type:"boolean"}
# Run the evaluation cells under "Test Model" against STATE_DICT_MODEL_FILE.
EVAL_MODEL = False #@param {type:"boolean"}
# Train with siren_main_app.py without --compress or --resume-from.
PLAIN_TRAINING = False #@param {type:"boolean"}

In [4]:
#@title Dir Path(s) for Logging -> options:
#@markdown ---
# Passed as --logging_root to the training runs; each run logs into a timestamped sub-directory.
LOGGING_ROOT = "/content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/agp_pruning/" #@param ["/content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/sensitivity_pruning/", "/content/pruning/sensitivity_pruning", "/content/pruning/level_pruning", "/content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/agp_pruning/"] {allow-input: true}
# Passed as --logging_root to the evaluation runs.
LOGGING_ROOT_EVAL = "/content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/eval-res" #@param {type:"string"}
# Passed as --experiment_name to the pruning runs.
EXPERIMENT_NAME = 'train' #@param {type:"string"}

In [5]:
#@title DNN Arch. -> Options:
#@markdown ---
# Hidden features (layer width), passed as --n_hf.
N_HF = 64 #@param {type:"integer"}
# Number of hidden layers, passed as --n_hl.
N_HL = 5 #@param {type:"integer"}
# Used only by the PLAIN_TRAINING run; the pruning runs hard-code their own --num_epochs.
NUM_EPOCHS =  35000#@param {type:"integer"}
# Image side length, passed as --sidelength (spelling kept: later cells reference this name).
SIDELENGHT = 256 #@param {type:"integer"}

In [6]:
#@title Pruning -> Options:
#@markdown ---
# Each flag enables its own App Mode training cell; nothing makes them mutually
# exclusive, so enabling several runs them one after another.
PRUNE_MODEL_AGP  = True #@param {type:"boolean"}
SENSITIVITY_PRUNING  = False #@param {type:"boolean"}
LEVEL_PRUNING  = False #@param {type:"boolean"}
# Compression schedule YAML, passed as --compress.
PRUNING_SCHEDULER_FILE = "/content/siren64_5.schedule_agp.yaml" #@param {type:"string"}
# Checkpoint passed as --resume-from (training) or --exp-load-weights-from (evaluation).
STATE_DICT_MODEL_FILE = "/content/_mid_ckpt_epoch_299999.pth.tar" #@param {type:"string"}

In [7]:
#@title Tensorboard -> options:
#@markdown ---
# RUN_TENSORBOARD_UTIL = False #@param {type:"boolean"}
# Log directory for the (currently commented-out) %tensorboard cell in App Mode.
LOG_DIR = "/content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/sensitivity_pruning/" #@param {type:"string"}

## Setup Notebook

### Clone Project

In [8]:
import os
# Start from the Colab root directory before the setup cell runs its shell commands.
os.chdir('/content')

In [9]:
# Setup wd to remove trash

if CLONE_GITHUB_PROJECT:
    # Remove trash
    import os
    if os.path.exists(f'{PROJECT_NAME_PATH}') and os.path.isdir(f'{PROJECT_NAME_PATH}'):
        print("Cleaning from old project...")
        !rm -r {PROJECT_NAME_PATH}
    else:
        print("No project found.")
        pass

    import os
    logs_base_dir = os.path.join("/content/outputs", "summaries")
    os.makedirs(logs_base_dir, exist_ok=True)

    if os.path.exists(PROJECT_NAME_PATH) is False:
        !git clone {GITHUB_PROJECT_URL}
        os.chdir(PROJECT_NAME_PATH)
        !git checkout {BRANCH_NAME}
        if CMD_TOOL_NAME == None or len(CMD_TOOL_NAME) == 0:
            full_path_cmd = os.path.join(PROJECT_NAME_PATH, f'{PROJECT_NAME_PATH}/')
            os.chdir(full_path_cmd)
        else:
            full_path_cmd = os.path.join(PROJECT_NAME_PATH, f'{PROJECT_NAME_PATH}/{CMD_TOOL_NAME}')
            os.chdir(full_path_cmd)
    else:
        os.chdir(PROJECT_NAME_PATH)
        !git checkout {BRANCH_NAME}
        !git fetch
        if CMD_TOOL_NAME == None or len(CMD_TOOL_NAME) == 0:
            full_path_cmd = os.path.join(PROJECT_NAME_PATH, f'{PROJECT_NAME_PATH}/')
            os.chdir(full_path_cmd)
        else:
            full_path_cmd = os.path.join(PROJECT_NAME_PATH, f'{PROJECT_NAME_PATH}/{CMD_TOOL_NAME}')
            os.chdir(full_path_cmd)
        pass
    pass
else:
    print("No github project cloned and no branch activated and switched to!")
    pass

Cleaning from old project...
Cloning into 'distiller'...
remote: Enumerating objects: 462, done.[K
remote: Counting objects: 100% (462/462), done.[K
remote: Compressing objects: 100% (284/284), done.[K
remote: Total 7670 (delta 304), reused 326 (delta 172), pack-reused 7208[K
Receiving objects: 100% (7670/7670), 54.94 MiB | 29.83 MiB/s, done.
Resolving deltas: 100% (5414/5414), done.
Branch 'siren-support' set up to track remote branch 'siren-support' from 'origin'.
Switched to a new branch 'siren-support'


In [10]:
# Show the working directory left by the setup cell; relative paths in later cells resolve against it.
!pwd

/content/distiller


In [11]:
# !pip install -e .
if INSTALL_DISTILLER_PIP_REQ:
    !pip install -r requirements.txt

### Libraries

In [12]:
# Installing third party dependencies
if INSTALL_SOME_LIBS:
    print("Installing required libraries...")

    old_requirements = '/content/tmp_requirements.txt'
    !pip freeze > {old_requirements}
    dependencies_list = "pretrainedmodels,torchnet,xlsxwriter,gitpython,python-git,cmapy,sk-video,pytorch-model-summary,ConfigArgParse,tabulate,chart_studio,dash,dash_bootstrap_components".split(",")

    with open(old_requirements) as f:
        old_requirements_list = f.read().split("\n")
        for a_req in dependencies_list:
            found_req = False
            for old_req in old_requirements_list:
                if old_req.startswith(a_req):
                    print(f"{a_req} already installed!")
                    found_req = True
                    break
            if found_req is False:
                !pip install {a_req} -q
        pass
    !rm -f {old_requirements}
    pass

In [13]:
%matplotlib inline
from __future__ import print_function
from __future__ import division


if RUN_COLAB_CODE:


    # --------------------------------------------- #
    # Standard Library, plus some Third Party Libraries
    # --------------------------------------------- #

    DASH_TEMPLATES_LIST = ["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none"]

    from PIL import Image
    from functools import partial
    from pprint import pprint
    from tqdm import tqdm
    from typing import Tuple, Union


    import configargparse
    import copy
    import collections
    import datetime
    import functools
    import itertools
    import h5py
    import logging
    import math
    import os
    import operator
    import pickle
    import random
    import shutil
    import sys
    import re
    import tabulate 
    import time
    # import visdom


    from collections import OrderedDict
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np

    # --------------------------------------------- #
    # Data Science and Machine Learning Libraries
    # --------------------------------------------- #
    import matplotlib
    import matplotlib.pyplot as plt
    matplotlib.style.use('ggplot')
    import seaborn as sns

    import numpy as np
    import pandas as pd
    import sklearn

    from sklearn.model_selection import ParameterGrid
    from sklearn.model_selection import train_test_split

    # --------------------------------------------- #
    # Torch
    # --------------------------------------------- #
    import torch
    try:
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        import torch.optim as optim
        from torch.utils.data import DataLoader, Dataset
        # import torch.quantization
        # import torch.nn.utils.prune as prune

        from torch import nn, optim

    except Exception as err:
        print(err)
        print("torch not available!")
        pass

    from numpy import linalg as LA
    from scipy.stats import rankdata
    from collections import OrderedDict
    from torchvision import datasets, transforms
    from torch.utils.data.sampler import SubsetRandomSampler

    # --------------------------------------------- #
    # Import: torch_pruning
    # --------------------------------------------- #
    # import torch_pruning as tp


    # --------------------------------------------- #
    # Import: TorchVision
    # --------------------------------------------- #
    try:
        import torchvision
        from torchvision import datasets
        from torchvision import transforms
        from torchvision.transforms import Resize, Compose, ToTensor, CenterCrop, Normalize
        from torchvision.utils import save_image
    except:
        print("torchvision library not available!")
        pass

    # Plotly imports.
    # ----------------------------------------------- #
    import chart_studio.plotly as py
    import plotly.figure_factory as ff
    import plotly.express as px

    # --------------------------------------------- #
    # Import: skimage
    # --------------------------------------------- #
    try:
        import skimage
        import skimage.metrics as skmetrics
        from skimage.metrics import peak_signal_noise_ratio as psnr
        from skimage.metrics import structural_similarity as ssim
        from skimage.metrics import mean_squared_error
    except:
        print("skimage library not available!")
        pass

### PyTorch Architectures

In [14]:
if RUN_COLAB_CODE:
    class SineLayer(nn.Module):
        """nn.Linear followed by a scaled sine activation: sin(omega_0 * linear(x)).

        Args:
            in_features: input size of the wrapped nn.Linear.
            out_features: output size of the wrapped nn.Linear.
            bias: passed to the wrapped nn.Linear.
            is_first: select the first-layer weight initialisation (see init_weights).
            omega_0: factor multiplying the linear output before the sine.
        """
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.
        
        # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the 
        # nonlinearity. Different signals may require different omega_0 in the first layer - this is a 
        # hyperparameter.
        
        # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of 
        # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)
        
        def __init__(self, in_features, out_features, bias=True,
                    is_first=False, omega_0=30):
            super().__init__()
            self.omega_0 = omega_0
            self.is_first = is_first
            
            self.in_features = in_features
            self.linear = nn.Linear(in_features, out_features, bias=bias)
            
            # Replace nn.Linear's default weight init with the SIREN scheme.
            self.init_weights()
            pass
        
        def init_weights(self):
            """Re-draw self.linear.weight in place; the bias keeps nn.Linear's default init.

            First layer:  U(-1/in_features, 1/in_features).
            Other layers: U(-sqrt(6/in_features)/omega_0, sqrt(6/in_features)/omega_0).
            """
            with torch.no_grad():
                if self.is_first:
                    self.linear.weight.uniform_(-1 / self.in_features, 
                                                1 / self.in_features)      
                else:
                    self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 
                                                np.sqrt(6 / self.in_features) / self.omega_0)
            pass
            
        def forward(self, input):
            """Return sin(omega_0 * linear(input))."""
            return torch.sin(self.omega_0 * self.linear(input))
        
        def forward_with_intermediate(self, input): 
            """Return (sin(z), z) with z = omega_0 * linear(input)."""
            # For visualization of activation distributions
            intermediate = self.omega_0 * self.linear(input)
            return torch.sin(intermediate), intermediate
        pass
        
        
    class Siren(nn.Module):
        """SIREN MLP built from SineLayers.

        Layout: SineLayer(in_features -> hidden_features, omega_0=first_omega_0), then
        `hidden_layers` SineLayers (hidden -> hidden, omega_0=hidden_omega_0), then a final
        hidden -> out_features layer: a plain nn.Linear when outermost_linear=True (weights
        drawn with the same bound as a hidden SineLayer), otherwise another SineLayer.
        Layers are constructed, and their weights drawn from torch's RNG, in this order,
        so reordering the construction changes the seeded initialisation.
        """
        def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, 
                    first_omega_0=30, hidden_omega_0=30.):
            super().__init__()
            
            self.net = []
            self.net.append(SineLayer(in_features, hidden_features, 
                                    is_first=True, omega_0=first_omega_0))

            for i in range(hidden_layers):
                self.net.append(SineLayer(hidden_features, hidden_features, 
                                        is_first=False, omega_0=hidden_omega_0))

            if outermost_linear:
                final_linear = nn.Linear(hidden_features, out_features)
                
                with torch.no_grad():
                    final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, 
                                                np.sqrt(6 / hidden_features) / hidden_omega_0)
                    
                self.net.append(final_linear)
            else:
                self.net.append(SineLayer(hidden_features, out_features, 
                                        is_first=False, omega_0=hidden_omega_0))
            
            self.net = nn.Sequential(*self.net)
            pass
        
        def forward(self, coords):
            """Return (output, coords), where coords is a detached copy of the input with
            requires_grad=True so derivatives w.r.t. the input coordinates can be taken."""
            coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
            output = self.net(coords)
            return output, coords        

        def forward_with_activations(self, coords, retain_grad=False):
            '''Returns not only model output, but also intermediate activations.
            Only used for visualizing activations later!'''
            # Keys: 'input', then '<layer class repr>_<counter>'. A SineLayer adds two
            # entries (pre-activation, then output); any other layer adds one.
            activations = OrderedDict()

            activation_count = 0
            x = coords.clone().detach().requires_grad_(True)
            activations['input'] = x
            for i, layer in enumerate(self.net):
                if isinstance(layer, SineLayer):
                    x, intermed = layer.forward_with_intermediate(x)
                    
                    if retain_grad:
                        x.retain_grad()
                        intermed.retain_grad()
                        
                    activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                    activation_count += 1
                else: 
                    x = layer(x)
                    
                    if retain_grad:
                        x.retain_grad()
                        
                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
                activation_count += 1

            return activations
        pass

## Training the Model

In [15]:
import os
# Run the remaining cells from the project root configured in the option cells
# (the setup cell may have cd'ed into the CMD_TOOL_NAME sub-directory).
os.chdir(PROJECT_NAME_PATH)

In [16]:
import distiller

In [17]:
# Show the version string of the imported distiller package (this checkout reported 'Unknown').
distiller.__version__

'Unknown'

### Colab mode

In [18]:
import os
import distiller
import torch.nn as nn
from distiller.models import register_user_model
import distiller.apputils.siren_image_regressor as regressor

In [19]:
if RUN_COLAB_CODE:
    def siren_model_64_5():
        """Build a SIREN with 2 input features, 5 hidden layers of width 64,
        1 output feature and a linear output layer."""
        return Siren(
            in_features=2,
            hidden_features=64,
            hidden_layers=5,
            out_features=1,
            outermost_linear=True,
        )

In [20]:
if RUN_COLAB_CODE:
    # Register the factory under a custom arch name, then build it through distiller's model registry.
    register_user_model(
        arch="SirenModel_64_5",
        dataset="cameramen",
        model=siren_model_64_5,
    )
    model = distiller.models.create_model(
        pretrained=False,
        dataset="cameramen",
        arch="SirenModel_64_5",
    )
    assert model is not None

In [21]:
if RUN_COLAB_CODE:
    # List every sub-module name (the root module appears as an empty string).
    for module_name, _module in model.named_modules():
        print(module_name)

In [22]:
if RUN_COLAB_CODE:
    def init_jupyter_default_args(args):
        """Overwrite `args` with fixed defaults for an interactive (notebook) run.

        Sets output, device, checkpoint, optimizer, data-split and logging fields.
        Mutates `args` in place and also returns it.
        """
        args.output_dir = '/content/' # None # 
        args.evaluate = False
        args.seed = 0
        args.deterministic = True
        args.cpu = False
        args.gpus = "0"
        args.load_serialized = False
        args.deprecated_resume = None
        args.resumed_checkpoint_path = None
        args.load_model_path = None
        args.reset_optimizer = False
        args.lr = args.momentum = args.weight_decay = 0.
        args.compress = '/content/distiller/examples/agp-pruning/siren64-5_schedule_agp.yaml'
        args.epochs = 0
        args.activation_stats = list()
        args.batch_size = 1
        args.workers = 1
        args.validation_split = 0.1
        args.effective_train_size = args.effective_valid_size = args.effective_test_size = 1.
        args.log_params_histograms = False
        args.print_freq = 10
        args.masks_sparsity = False
        args.display_confusion = False
        args.num_best_scores = 1
        args.name = ""
        args.kd_policy = None
        # args.summary = "sparsity"
        args.qe_stats_file = None
        args.verbose = False
        return args


    def config_learner_args(args, arch, dataset, dataset_path, pretrained, adam_args, batch, epochs):
        """Set the model, dataset and optimizer fields of `args` and return it.

        Args:
            args: namespace to update in place.
            arch: registered architecture name (stored as a string).
            dataset: dataset name (stored as a string).
            dataset_path: stored as args.data.
            pretrained: stored as args.pretrained.
            adam_args: sequence (lr, momentum, weight_decay).
            batch: stored as args.batch_size.
            epochs: stored as args.epochs.

        Returns:
            The same `args` object.
        """
        args.arch = f"{arch}"
        args.dataset = f"{dataset}"
        # Honour the caller's values; these three used to be hard-coded ("", False, 1)
        # and the corresponding parameters were silently ignored.
        args.data = dataset_path
        args.pretrained = pretrained
        args.batch_size = batch
        args.lr = adam_args[0]
        args.momentum = adam_args[1]
        args.weight_decay = adam_args[2]
        args.epochs = epochs
        return args

In [23]:
if RUN_COLAB_CODE:
    # Build the regressor's command-line parser and parse sys.argv with it.
    # parse_known_args() tolerates unrecognised entries (presumably the kernel's own
    # launch flags); they are collected in `unknownargs` and otherwise ignored.
    parser = regressor.init_regressor_compression_arg_parser()
    args, unknownargs = parser.parse_known_args()
    pprint(args)

In [24]:
if RUN_COLAB_CODE:
    args = init_jupyter_default_args(args)
    # A bare expression inside an `if` block is not echoed by IPython, so print it explicitly.
    print(args.batch_size)

In [25]:
if RUN_COLAB_CODE:
    # Keywords make the magic values self-describing; optimizer args are (lr, momentum, weight_decay).
    args = config_learner_args(
        args,
        arch="SirenModel_64_5",
        dataset="cameramen",
        dataset_path="",
        pretrained=False,
        adam_args=(0.1, 0.0, 1e-4),
        batch=1,
        epochs=100,
    )
    # A bare expression inside an `if` block is not echoed by IPython, so print it explicitly.
    print(args.arch, args.epochs)

In [26]:
if RUN_COLAB_CODE:
    # A bare expression inside an `if` block is not echoed by IPython, so print it explicitly.
    print(args.arch, args.verbose, args.print_freq)

In [27]:
if RUN_COLAB_CODE:
    # A notebook has no __file__; the previous os.path.dirname(".") evaluated to "",
    # so pass the current working directory explicitly as the script directory.
    app = regressor.SirenRegressorCompressor(args, script_dir=os.getcwd())

In [28]:
# %load_ext tensorboard
# %reload_ext tensorboard

In [29]:
# !kill 6966

In [30]:
# %tensorboard --logdir /content/logs

In [31]:
if RUN_COLAB_CODE:
    # Run the training loop; the returned score history is inspected in the next cell.
    perf_scores_history = app.run_training_loop()

In [32]:
if RUN_COLAB_CODE:
    # Most recent entry of the history returned by run_training_loop().
    last_scores = perf_scores_history[-1]
    print(last_scores)

### Base mode

In [33]:
if RUN_MAIN_SIREN_BASE:
    !python main.py \
        --logging_root '/content/results/cameramen' \
        --experiment_name 'train' \
        --sidelength 256 \
        --num_epochs 100 \
        --n_hf 64  \
        --n_hl 5 \
        --lambda_L_1 0 \
        --lambda_L_2 0.0001 \
        --epochs_til_ckpt 10 \
        --seed 0 \
        --cuda \
        --train \
        --evaluate \
        --dynamic_quant qint8 qfloat16 \
        --verbose 0
    pass

### App mode

In [34]:
# if RUN_TENSORBOARD_UTIL and RUN_MAIN_SIREN_APP:
# %reload_ext tensorboard

In [35]:
# if RUN_TENSORBOARD_UTIL:
# %tensorboard --logdir {LOG_DIR}

In [None]:
if RUN_MAIN_SIREN_APP and PRUNE_MODEL_AGP:
    !python siren_main_app.py \
        --logging_root {LOGGING_ROOT} \
        --experiment_name {EXPERIMENT_NAME} \
        --sidelength {SIDELENGHT} \
        --num_epochs 475000 \
        --lr 0.0001 \
        --n_hf {N_HF} \
        --n_hl {N_HL} \
        --lambda_L_1 0 \
        --lambda_L_2 0 \
        --epochs_til_ckpt 850 \
        --num-best-scores 3 \
        --compress {PRUNING_SCHEDULER_FILE} \
        --seed 0 \
        --cuda \
        --train \
        --evaluate \
        --verbose 0 \
        --resume-from {STATE_DICT_MODEL_FILE} \
        --target_sparsity 40.0 \
        --toll_sparsity 2.0 \
        --patience_sparsity 1000 \
        --trail_epochs 1000 \
        --mid_target_sparsities 5 10 20 25 30 35 40
    pass

Could not find the logger configuration file (/content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/agp_pruning/logging.conf) - using default logger configuration
Log file for this run: /content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/agp_pruning/___2020.12.15-093035/___2020.12.15-093035.log
Random seed: 0

--------------------------------------------------------
Logging to TensorBoard - remember to execute the server:
> tensorboard --logdir='./logs'

=> created a SirenCompressingModel model with the cameramen dataset
=> loading checkpoint /content/_mid_ckpt_epoch_299999.pth.tar
=> Checkpoint contents:
+----------------------+-------------+-----------------------+
| Key                  | Type        | Value                 |
|----------------------+-------------+-----------------------|
| arch                 | str         | SirenCompressingModel |
| compression_sched    | dict        |                       |
| dataset              | str         |

In [None]:
if RUN_MAIN_SIREN_APP and SENSITIVITY_PRUNING:
    # {NUM_EPOCHS} \
    !python siren_main_app.py \
            --logging_root {LOGGING_ROOT} \
            --experiment_name {EXPERIMENT_NAME} \
            --sidelength {SIDELENGHT} \
            --num_epochs 450000 \
            --lr 0.001 \
            --n_hf {N_HF} \
            --n_hl {N_HL} \
            --lambda_L_1 0 \
            --lambda_L_2 0 \
            --epochs_til_ckpt 950 \
            --num-best-scores 5 \
            --compress {PRUNING_SCHEDULER_FILE} \
            --seed 0 \
            --cuda \
            --train \
            --evaluate \
            --verbose 0 \
            --resume-from {STATE_DICT_MODEL_FILE}
    # --reset-optimizer \
    # --exp-load-weights-from
    pass

In [None]:
if RUN_MAIN_SIREN_APP and LEVEL_PRUNING:
    !python siren_main_app.py \
        --logging_root {LOGGING_ROOT} \
        --experiment_name {EXPERIMENT_NAME} \
        --sidelength {SIDELENGHT} \
        --num_epochs 434999 \
        --lr 0.0001 \
        --n_hf {N_HF} \
        --n_hl {N_HL} \
        --lambda_L_1 0 \
        --lambda_L_2 0 \
        --epochs_til_ckpt 200 \
        --num-best-scores 5 \
        --compress {PRUNING_SCHEDULER_FILE} \
        --seed 0 \
        --cuda \
        --train \
        --evaluate \
        --verbose 0 \
        --resume-from {STATE_DICT_MODEL_FILE}

In [None]:
if RUN_MAIN_SIREN_APP and PLAIN_TRAINING:
    !python siren_main_app.py \
            --logging_root {LOGGING_ROOT} \
            --experiment_name 'train' \
            --sidelength {SIDELENGHT} \
            --num_epochs {NUM_EPOCHS} \
            --lr 0.001 \
            --n_hf {N_HF} \
            --n_hl {N_HL} \
            --lambda_L_1 0 \
            --lambda_L_2 0 \
            --epochs_til_ckpt 500 \
            --num-best-scores 5 \
            --seed 0 \
            --cuda \
            --train \
            --evaluate \
            --verbose 0

## Testing the Model

In [None]:
if EVAL_MODEL:
    !python siren_main_app.py \
            --logging_root {LOGGING_ROOT_EVAL} \
            --experiment_name 'train' \
            --sidelength {SIDELENGHT} \
            --n_hf {N_HF} \
            --n_hl {N_HL} \
            --seed 0 \
            --cuda \
            --evaluate \
            --verbose 0 \
            --exp-load-weights-from {STATE_DICT_MODEL_FILE}

Could not find the logger configuration file (/content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/eval-res/logging.conf) - using default logger configuration
Log file for this run: /content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/eval-res/___2020.12.05-104940/___2020.12.05-104940.log
Random seed: 0

--------------------------------------------------------
Logging to TensorBoard - remember to execute the server:
> tensorboard --logdir='./logs'

=> created a SirenCompressingModel model with the cameramen dataset
=> loading checkpoint /content/_mid_ckpt_epoch_299999.pth.tar
=> Checkpoint contents:
+----------------------+-------------+-----------------------+
| Key                  | Type        | Value                 |
|----------------------+-------------+-----------------------|
| arch                 | str         | SirenCompressingModel |
| compression_sched    | dict        |                       |
| dataset              | str         | cameramen             |
| epoch                | int         | 299999                |
| extras               | dict        |                       |
| is_parallel          | bool        | True                  |
| optimizer_state_dict | dict        |                       |
| optimizer_type       | type        | Adam                  |
| state_dict           | OrderedDict |                       |
+----------------------+-------------+-----------------------+

=> Checkpoint['extras'] contents:
+--------------------+---------+------------------+
| Key                | Type    |            Value |
|--------------------+---------+------------------|
| best_epoch         | int     | 296800           |
| best_mse           | float   |      3.07749e-05 |
| best_psnr_score    | float64 |     51.1405      |
| best_ssim_score    | float64 |      0.996148    |
| current_mse        | float   |      4.85854e-05 |
| current_psnr_score | float64 |      0.995518    |
+--------------------+---------+------------------+

Loaded compression schedule from checkpoint (epoch 299999)
=> loaded 'state_dict' from checkpoint '/content/_mid_ckpt_epoch_299999.pth.tar'
Dataset sizes:
	test=1
--- test ---------------------
==> Loss: 0.0000486   PSNR: 49.1567465   SSIM: 0.9955182


Log file for this run: /content/drive/MyDrive/Siren-Deep-Learning-Analyses/results/cameramen/eval-res/___2020.12.05-104940/___2020.12.05-104940.log


In [None]:
if EVAL_MODEL:
    !python siren_main_app.py \
            --logging_root {LOGGING_ROOT_EVAL} \
            --experiment_name 'train' \
            --sidelength {SIDELENGHT} \
            --n_hf {N_HF} \
            --n_hl {N_HL} \
            --seed 0 \
            --cuda \
            --summary 'sparsity' \
            --compress {PRUNING_SCHEDULER_FILE} \
            --verbose 0 \
            --exp-load-weights-from {STATE_DICT_MODEL_FILE}

In [None]:
if EVAL_MODEL and LEVEL_PRUNING:
    !python siren_main_app.py \
            --logging_root {LOGGING_ROOT_EVAL} \
            --experiment_name 'train' \
            --sidelength {SIDELENGHT} \
            --n_hf {N_HF} \
            --n_hl {N_HL} \
            --seed 0 \
            --cuda \
            --evaluate \
            --compress {PRUNING_SCHEDULER_FILE} \
            --verbose 0 \
            --exp-load-weights-from {STATE_DICT_MODEL_FILE}