In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

# The project root is two directory levels above the notebook's working
# directory; it must be importable so the project-local imports below resolve.
proj_root = os.path.dirname(os.path.dirname(os.path.abspath(".")))
# Guard the append so re-running this cell does not keep adding duplicate
# entries to sys.path.
if proj_root not in sys.path:
    sys.path.append(proj_root)

from minatar import Environment

from minatar_dqn.my_dqn import Conv_QNET

from experiments.experiment_utils import (
    seed_everything,
    search_files_containing_string,
    split_path_at_substring,
    collect_training_output_files,
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- Path configuration -----------------------------------------------------
# Training run to analyse, identified by its timestamped output folder name.
training_timestamp_folder = "2023_05_15-18_23_40"

# Root folder of all training outputs, and the folder of the selected run.
training_outputs_folder_path = os.path.join(
    proj_root, "experiments", "training", "outputs"
)
experiments_folder = os.path.join(
    training_outputs_folder_path, training_timestamp_folder
)

# Pruning outputs live in an "outputs" folder under the parent of the current
# working directory.
file_dir = os.path.dirname(os.path.abspath("."))
pruning_outputs_folder_path = os.path.join(file_dir, "outputs")


In [7]:
# Gather the output files of every experiment in the selected training run.
# `experiments_folder` was built in the path-setup cell as
# os.path.join(training_outputs_folder_path, training_timestamp_folder), so it
# is reused here instead of rebuilding the same path a second time.
experiment_paths = collect_training_output_files(experiments_folder)

# Each record carries the path of a model checkpoint; also store the folder
# that contains it so the sibling checkpoints can be located later.
for exp in experiment_paths:
    exp["models_folder_path"] = os.path.dirname(exp["model_path"])

In [8]:
# Display the collected records: one dict per experiment holding its
# model_path, training_folder_path, config_path, stats_path and
# models_folder_path.
experiment_paths

[{'model_path': 'd:\\Work\\PhD\\minatar_work\\experiments\\training\\outputs\\2023_05_15-18_23_40\\conv16_lin64\\breakout\\3\\model_checkpoints\\mck_20',
  'training_folder_path': 'd:\\Work\\PhD\\minatar_work\\experiments\\training\\outputs\\2023_05_15-18_23_40\\conv16_lin64\\breakout\\3',
  'config_path': 'd:\\Work\\PhD\\minatar_work\\experiments\\training\\outputs\\2023_05_15-18_23_40\\conv16_lin64\\breakout\\3\\conv16_lin64_breakout_3_config',
  'stats_path': 'd:\\Work\\PhD\\minatar_work\\experiments\\training\\outputs\\2023_05_15-18_23_40\\conv16_lin64\\breakout\\3\\conv16_lin64_breakout_3_train_stats',
  'models_folder_path': 'd:\\Work\\PhD\\minatar_work\\experiments\\training\\outputs\\2023_05_15-18_23_40\\conv16_lin64\\breakout\\3\\model_checkpoints'},
 {'model_path': 'd:\\Work\\PhD\\minatar_work\\experiments\\training\\outputs\\2023_05_15-18_23_40\\conv32_lin128\\breakout\\3\\model_checkpoints\\mck_20',
  'training_folder_path': 'd:\\Work\\PhD\\minatar_work\\experiments\\traini

For each checkpoint of an experiment, we want to prune with different pruning factors and compare against different thresholds and redo scores.

In [17]:
# Work with the first collected experiment only.
exp_paths = experiment_paths[0]

# Load the training statistics saved for this experiment.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load stats files produced by trusted training runs.
training_stats_data = torch.load(exp_paths["stats_path"])

# Redo scores recorded for the "policy" entry of the stats; presumably one
# item per saved checkpoint - TODO confirm against the training code.
redo_scores = training_stats_data["redo_scores"]["policy"]
len(redo_scores)

21

In [18]:
# Inspect the first recorded entry; the saved output shows a list of three
# 1-D tensors (presumably one per network layer - TODO confirm).
redo_scores[0]

[tensor([0.0220, 0.0466, 0.0358, 0.0445, 0.0747, 0.0760, 0.0832, 0.0382, 0.0473,
         0.0343, 0.1171, 0.0418, 0.0772, 0.1289, 0.1208, 0.0115]),
 tensor([0.0319, 0.0292, 0.0653, 0.0699, 0.0310, 0.0837, 0.0540, 0.0850, 0.0484,
         0.0311, 0.0333, 0.1288, 0.1299, 0.0699, 0.0915, 0.0173]),
 tensor([0.0000, 0.0557, 0.0000, 0.0041, 0.0000, 0.0000, 0.0183, 0.0000, 0.0735,
         0.0797, 0.0000, 0.0331, 0.1031, 0.0000, 0.0000, 0.0000, 0.0022, 0.0140,
         0.0342, 0.0320, 0.0000, 0.0085, 0.0129, 0.0000, 0.0630, 0.0000, 0.0000,
         0.0086, 0.0120, 0.0000, 0.0000, 0.0098, 0.0887, 0.0078, 0.0055, 0.0533,
         0.0024, 0.0000, 0.0000, 0.0128, 0.0000, 0.0000, 0.0000, 0.0179, 0.0006,
         0.0190, 0.0122, 0.0124, 0.0000, 0.0124, 0.0043, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0823, 0.0426, 0.0000, 0.0000, 0.0463, 0.0000, 0.0033, 0.0000,
         0.0114])]

In [21]:
# Inspect the entry at index 20, which is the last one given the length of 21
# reported above.
redo_scores[20]

[tensor([0.0326, 0.0280, 0.0504, 0.0595, 0.0562, 0.0483, 0.0714, 0.0667, 0.0425,
         0.0880, 0.0984, 0.0570, 0.0631, 0.0832, 0.1093, 0.0455]),
 tensor([0.0761, 0.0547, 0.0495, 0.0769, 0.0566, 0.0865, 0.0752, 0.0657, 0.0569,
         0.0691, 0.0589, 0.0425, 0.0630, 0.0495, 0.0692, 0.0498]),
 tensor([1.4707e-03, 6.4590e-02, 1.1963e-02, 1.9946e-03, 1.9096e-03, 2.7781e-04,
         3.6924e-02, 1.4560e-03, 6.9540e-02, 7.8125e-02, 5.0273e-04, 8.9667e-03,
         8.9484e-02, 5.6402e-04, 2.3150e-03, 1.3857e-02, 2.6151e-03, 2.7216e-03,
         3.2223e-02, 3.0299e-03, 0.0000e+00, 9.4393e-03, 3.4041e-03, 0.0000e+00,
         5.2601e-02, 1.3624e-02, 0.0000e+00, 4.2704e-03, 1.1238e-02, 2.9081e-03,
         1.7461e-03, 5.2460e-03, 8.2055e-02, 7.4827e-03, 1.5347e-02, 5.8933e-02,
         2.0834e-02, 6.9320e-07, 2.0289e-03, 1.8741e-03, 2.8097e-03, 3.2972e-05,
         2.9016e-03, 1.3664e-02, 5.2795e-04, 2.3168e-03, 3.1712e-03, 6.3628e-04,
         1.4953e-02, 1.8250e-02, 1.5099e-03, 8.6197e-05,