##### Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
%%bash
# Download the official ResNet50 implementation and the other libraries we
# depend on, so that we can use the ResNet50 model builders for our counting.
#
# `%%bash` has to be the very first line of the cell, and every step is guarded
# by the directory it finally produces, so re-running this cell is a no-op
# instead of failing on an already-moved or already-cloned directory.
if ! test -d resnet50_keras; then
  { test -d tpu || git clone https://github.com/tensorflow/tpu tpu; } &&
    mv tpu/models/experimental/resnet50_keras ./
fi
if ! test -d rigl; then
  { test -d rigl_repo || git clone https://github.com/google-research/rigl rigl_repo; } &&
    mv rigl_repo/rigl ./
fi
test -d google_research ||
  git clone https://github.com/google-research/google-research google_research

In [None]:
import numpy as np
import tensorflow as tf
from micronet_challenge import counting
from resnet50_keras import resnet_model as resnet_keras
from rigl import sparse_utils
# Only emit TensorFlow (compat.v1) log messages at ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


In [None]:
# Clear the TF1 default graph, then build the Keras ResNet50 model.
tf.compat.v1.reset_default_graph()
# Presumably the number of output classes (cf. the `fc1000` layer) -- TODO confirm.
NUM_CLASSES = 1000
model = resnet_keras.ResNet50(NUM_CLASSES)

In [None]:
# Keep only the Conv2D and Dense layers of the model, in model order; these are
# the layers handed to the stats helpers in the cells below.
masked_layers = [
    candidate for candidate in model.layers
    if isinstance(candidate, (tf.keras.layers.Conv2D, tf.keras.layers.Dense))
]


In [None]:
import functools

# Size of a single parameter, in bits.
PARAM_SIZE = 32

# `sparse_utils.get_stats` with the first/last layer names and the parameter
# size used throughout this notebook already bound.
get_stats = functools.partial(
    sparse_utils.get_stats,
    first_layer_name='conv1',
    last_layer_name='fc1000',
    param_size=PARAM_SIZE,
)
def print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi',
                custom_sparsities=None, is_debug=False, width=1., **kwargs):
  """Prints FLOPs, size and resulting sparsity for one configuration.

  All arguments are forwarded to `get_stats`; this function only formats the
  three values it returns.

  Args:
    masked_layers: layers the statistics are computed for.
    default_sparsity: value forwarded as `default_sparsity`.
    method: name of the sparsity distribution, forwarded as `method`.
    custom_sparsities: optional dict forwarded as `custom_sparsities`. `None`
      (the default) is replaced by a fresh empty dict on every call, so no
      dict object is shared between calls.
    is_debug: forwarded as `is_debug`.
    width: forwarded as `width`.
    **kwargs: any further keyword arguments accepted by `get_stats`.

  Returns:
    None. The results are written to stdout.
  """
  if custom_sparsities is None:
    custom_sparsities = {}
  print('Method: %s, Sparsity: %f' % (method, default_sparsity))
  total_flops, total_param_bits, sparsity = get_stats(
      masked_layers, default_sparsity=default_sparsity, method=method,
      custom_sparsities=custom_sparsities, is_debug=is_debug, width=width,
      **kwargs)
  print('Total Flops: %.3f MFlops' % (total_flops/1e6))
  # 8 bits per byte and 1e6 bytes per Mbyte.
  print('Total Size: %.3f Mbytes' % (total_param_bits/8e6))
  print('Real Sparsity: %.3f' % (sparsity))

# Pruning FLOPs
We calculate the theoretical FLOPs for pruning: the network is counted as dense until pruning starts, and sparse FLOPs are counted only from that point on.

In [None]:
# Pruning starts at `p_start`, ends at `p_end` and is sampled every `p_freq`
# steps; FLOPs are averaged over steps 0..total_steps.
p_start, p_end, p_freq = 10000, 25000, 1000
total_steps = 32000
target_sparsity = 0.8
# Final sparsity of the `fc1000` layer (the value the original cell used).
fc_target_sparsity = 0.8


def pruning_sparsity(step, final_sparsity):
  """Returns the sparsity at `step` of the cubic pruning schedule.

  The result is 0 before `p_start`, `final_sparsity` after `p_end`, and
  (1 - (1 - progress)**3) * final_sparsity in between, where `progress` is
  the fraction of the [p_start, p_end] interval that has elapsed.
  """
  if step < p_start:
    return 0.
  if step > p_end:
    return final_sparsity
  progress = (step - p_start) / float(p_end - p_start)
  return (1 - (1 - progress) ** 3) * final_sparsity


total_flops = []
for step in range(0, total_steps + 1, p_freq):
  # The last layer follows the same schedule as the rest of the network.
  # Hard-coding its final sparsity from step 0 would count sparse FLOPs before
  # pruning has started, and total_flops[0] would no longer be the dense
  # baseline that the ratio printed below divides by.
  c_flops, _, _ = get_stats(
      masked_layers,
      default_sparsity=pruning_sparsity(step, target_sparsity),
      method='random',
      custom_sparsities={
          'conv1/kernel:0': 0,
          'fc1000/kernel:0': pruning_sparsity(step, fc_target_sparsity),
      })
  total_flops.append(c_flops)
avg_flops = sum(total_flops) / len(total_flops)
# Second number: average FLOPs relative to the step-0 (dense) FLOPs.
print('Average Flops: ', avg_flops, avg_flops/total_flops[0])

### Printing sparse network stats.

In [None]:
# Sparsity 0.8 with each of the three `method` settings, followed by a run at
# sparsity 0 for reference.
print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi_kernel',
            is_debug=True, erk_power_scale=0.2)
print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi')
print_stats(masked_layers, default_sparsity=0.8, method='random',
            custom_sparsities={'conv1/kernel:0': 0, 'fc1000/kernel:0': 0.8})
print_stats(masked_layers, default_sparsity=0, method='random')

In [None]:
# Sparsity 0.9 with each of the three `method` settings.
print_stats(masked_layers, default_sparsity=0.9, method='erdos_renyi_kernel')
print_stats(masked_layers, default_sparsity=0.9, method='erdos_renyi')
print_stats(masked_layers, default_sparsity=0.9, method='random',
            custom_sparsities={'conv1/kernel:0': 0., 'fc1000/kernel:0': 0.9})


In [None]:
# Sparsity 0.95 with each of the three `method` settings.
print_stats(masked_layers, default_sparsity=0.95, method='erdos_renyi_kernel')
print_stats(masked_layers, default_sparsity=0.95, method='erdos_renyi')
print_stats(masked_layers, default_sparsity=0.95, method='random',
            custom_sparsities={'conv1/kernel:0': 0})

In [None]:
# Sparsity 0.965 with each of the three `method` settings.
print_stats(masked_layers, default_sparsity=0.965, method='erdos_renyi_kernel',
            custom_sparsities={'conv1/kernel:0': 0})
print_stats(masked_layers, default_sparsity=0.965, method='erdos_renyi')
print_stats(masked_layers, default_sparsity=0.965, method='random',
            custom_sparsities={'conv1/kernel:0': 0})


## Finding the width multiplier for the small dense model

In [None]:
# Parameter bits of the 0.8-sparse network relative to a dense network built
# at width 0.465, followed by the full stats of both.
small_dense_width = 0.465
sparse_bits = get_stats(masked_layers, 0.8, 'erdos_renyi_kernel')[1]
bits = get_stats(
    masked_layers, 0., 'erdos_renyi_kernel', width=small_dense_width)[1]
print(sparse_bits / bits)
print_stats(masked_layers, 0., 'erdos_renyi_kernel', width=small_dense_width)
print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', width=1)

In [None]:
# Parameter bits of the 0.9-sparse network relative to a dense network built
# at width 0.34, followed by the full stats of both.
small_dense_width = 0.34
sparse_bits = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel')[1]
bits = get_stats(
    masked_layers, 0., 'erdos_renyi_kernel', width=small_dense_width)[1]
print(sparse_bits / bits)
print_stats(masked_layers, 0., 'erdos_renyi_kernel', width=small_dense_width)
print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', width=1)

In [None]:
# Parameter bits of the 0.95-sparse network relative to a dense network built
# at width 0.26, followed by the full stats of both.
small_dense_width = 0.26
sparse_bits = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel')[1]
bits = get_stats(
    masked_layers, 0., 'erdos_renyi_kernel', width=small_dense_width)[1]
print(sparse_bits / bits)
print_stats(masked_layers, 0., 'erdos_renyi_kernel', width=small_dense_width)
print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', width=1)

In [None]:
# Parameter bits of the 0.965-sparse network relative to a dense network built
# at width 0.231, followed by the full stats of both.
small_dense_width = 0.231
sparse_bits = get_stats(masked_layers, 0.965, 'erdos_renyi_kernel')[1]
bits = get_stats(
    masked_layers, 0., 'erdos_renyi_kernel', width=small_dense_width)[1]
print(sparse_bits / bits)
print_stats(masked_layers, 0., 'erdos_renyi_kernel', width=small_dense_width)
print_stats(masked_layers, 0.965, 'erdos_renyi_kernel', width=1)

### Printing the Big-Sparse Results

In [None]:
# BIGGER
# Note the naming in this cell: `sparse_bits` holds the bits of the dense
# width-1 network and `bits` those of the 0.8-sparse network at width 2.1, so
# the printed ratio is dense / big-sparse.
big_sparse_width = 2.1
sparse_bits = get_stats(masked_layers, 0., 'erdos_renyi_kernel')[1]
bits = get_stats(
    masked_layers, 0.8, 'erdos_renyi_kernel', width=big_sparse_width)[1]
print(sparse_bits / bits)
print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', width=big_sparse_width)
print_stats(masked_layers, 0.8, 'random',
            custom_sparsities={'conv1/kernel:0': 0, 'fc1000/kernel:0': 0.8},
            width=big_sparse_width)
print_stats(masked_layers, 0., 'erdos_renyi_kernel', width=big_sparse_width)
print_stats(masked_layers, 0., 'erdos_renyi_kernel', width=1)

In [None]:
# Note the naming in this cell: `sparse_bits` holds the bits of the dense
# width-1 network and `bits` those of the 0.9-sparse network at width 2.8, so
# the printed ratio is dense / big-sparse.
big_sparse_width = 2.8
sparse_bits = get_stats(masked_layers, 0., 'erdos_renyi_kernel')[1]
bits = get_stats(
    masked_layers, 0.9, 'erdos_renyi_kernel', width=big_sparse_width)[1]
print(sparse_bits / bits)
print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', width=big_sparse_width)
print_stats(masked_layers, 0.9, 'random',
            custom_sparsities={'conv1/kernel:0': 0, 'fc1000/kernel:0': 0.8},
            width=big_sparse_width)
print_stats(masked_layers, 0., 'erdos_renyi_kernel', width=big_sparse_width)
print_stats(masked_layers, 0., 'erdos_renyi_kernel', width=1)

## [BONUS] DSR FLOPs
The layer-wise sparsities below are read off a figure in the DSR paper (https://arxiv.org/abs/1902.05967), so the exact values are probably slightly different.



In [None]:
# Kernel names in the order used by the per-layer sparsity lists: `conv1`,
# then groups res2..res5, then `fc1000`. Every block contributes
# branch2a/2b/2c; block `a` of each group additionally ends with `branch1`.
resnet_layers = (
    ['conv1/kernel:0'] +
    ['res%d%s_%s/kernel:0' % (group, block, branch)
     for group, blocks in ((2, 'abc'), (3, 'abcd'), (4, 'abcdef'), (5, 'abc'))
     for block in blocks
     for branch in (('branch2a', 'branch2b', 'branch2c', 'branch1')
                    if block == 'a' else
                    ('branch2a', 'branch2b', 'branch2c'))] +
    ['fc1000/kernel:0'])
# Per-layer sparsities, listed in the same order as `resnet_layers`. Each
# sub-list covers one run of consecutive layer names (named in the trailing
# comment).
dsr_sparsities8 = (
    [0] +  # conv1
    [0., .15, .5, .425, .575, .55, .425, .32, .44, .15] +  # res2*
    [0., .15, .55, .6, .8, .65, .75, .65, .65, .65, .55, .65, .7] +  # res3*
    [0., .35, .65, .85, .9, .8, .85, .85, .8, .85, .85, .85, .85, .8, .8, .9,
     .75, .8, .85] +  # res4*
    [0., .65, .85, .95, .85, .8, .9, .65, .9, .8] +  # res5*
    [.8])  # fc1000
dsr_sparsities9 = (
    [0] +  # conv1
    [0., .4, .6, .65, .65, .6, .6, .5, .6, .45] +  # res2*
    [0., .4, .7, .8, .9, .8, .85, .8, .75, .8, .7, .8, .8] +  # res3*
    [0., .6, .8, .95, .95, .9, .95, .9, .9, .95, .9, .9, .95, .9, .9, .95,
     .85, .85, .9] +  # res4*
    [0., 0.8, .95, .95, .9, .9, .95, .8, .95, .9] +  # res5*
    [.9])  # fc1000

In [None]:
# Pair every layer name with its entry in `dsr_sparsities8` and report the
# stats obtained with those per-layer sparsities.
dsr_map = {
    layer_name: layer_sparsity
    for layer_name, layer_sparsity in zip(resnet_layers, dsr_sparsities8)
}
print_stats(masked_layers, 0., 'random', custom_sparsities=dsr_map, width=1)

In [None]:
# Pair every layer name with its entry in `dsr_sparsities9` and report the
# stats obtained with those per-layer sparsities.
dsr_map = {
    layer_name: layer_sparsity
    for layer_name, layer_sparsity in zip(resnet_layers, dsr_sparsities9)
}
print_stats(masked_layers, 0., 'random', custom_sparsities=dsr_map, width=1)

# [BONUS] STR FLOPs
Layerwise sparsities are obtained from the [STR paper](https://arxiv.org/abs/2002.03231).

In [None]:
# Raw layer-wise table, one row per layer:
#   `Layer <index> - <layer name> <int> <int>` followed by 16 percentage columns.
# The two integers are presumably the layer's parameter count and FLOPs -- TODO
# confirm. The parsing cell further down skips them and divides every
# percentage by 100.
# NOTE(review): the parsing cell also looks for a row whose name is `Overall`;
# this table does not contain one.
str_sparsities = """
Layer 1 - conv1 9408 118013952 51.46 51.40 63.02 59.80 59.83 64.87 67.36 66.96 72.11 69.46 73.29 73.47 72.05 75.12 76.12 77.75
Layer 2 - layer1.0.conv1 4096 12845056 69.36 73.24 87.57 83.28 85.18 89.60 91.41 91.11 92.38 91.75 94.46 94.51 94.60 95.95 96.53 96.51
Layer 3 - layer1.0.conv2 36864 115605504 77.85 76.26 90.87 89.48 87.31 94.79 94.27 95.04 95.69 96.07 97.36 97.77 98.35 98.51 98.59 98.84
Layer 4 - layer1.0.conv3 16384 51380224 74.81 74.65 86.52 85.80 85.25 91.85 92.78 93.67 94.13 94.69 96.61 97.03 97.37 98.04 98.21 98.47
Layer 5 - layer1.0.downsample.0 16384 51380224 70.95 72.96 83.53 83.34 82.56 89.13 90.62 90.17 91.83 92.69 95.48 94.89 95.68 96.98 97.56 97.72
Layer 6 - layer1.1.conv1 16384 51380224 80.27 79.58 89.82 89.89 88.51 94.56 96.64 95.78 95.81 96.81 98.79 98.90 98.98 99.13 99.62 99.47
Layer 7 - layer1.1.conv2 36864 115605504 81.36 80.95 91.75 90.60 89.61 94.70 95.78 96.18 96.42 97.26 98.65 99.07 99.40 99.11 99.31 99.56
Layer 8 - layer1.1.conv3 16384 51380224 84.45 80.11 91.22 91.70 90.21 95.17 97.05 95.81 96.34 97.23 98.68 98.76 98.90 99.16 99.57 99.46
Layer 9 - layer1.2.conv1 16384 51380224 78.23 79.79 90.12 88.07 89.36 94.62 95.94 94.74 96.23 96.75 97.96 98.41 98.72 99.38 99.35 99.46
Layer 10 - layer1.2.conv2 36864 115605504 76.01 81.53 91.06 87.03 88.27 93.90 95.63 94.26 96.24 96.11 97.54 98.27 98.44 99.32 99.19 99.39
Layer 11 - layer1.2.conv3 16384 51380224 84.47 83.28 94.95 90.99 92.64 95.76 96.95 96.01 96.87 97.31 98.38 98.60 98.72 99.38 99.27 99.51
Layer 12 - layer2.0.conv1 32768 102760448 73.74 73.96 86.78 85.95 85.90 92.32 94.79 93.86 94.62 95.64 97.19 98.22 98.52 98.48 98.84 98.92
Layer 13 - layer2.0.conv2 147456 115605504 82.56 85.70 91.31 93.91 94.03 97.54 97.43 97.65 98.38 98.62 99.24 99.23 99.40 99.61 99.67 99.63
Layer 14 - layer2.0.conv3 65536 51380224 84.70 83.55 93.04 93.13 92.13 96.61 97.37 97.21 97.59 98.14 98.80 98.95 99.18 99.29 99.47 99.43
Layer 15 - layer2.0.downsample.0 131072 102760448 85.10 87.66 92.78 94.96 95.13 98.07 97.97 98.15 98.70 98.88 99.37 99.35 99.40 99.69 99.68 99.71
Layer 16 - layer2.1.conv1 65536 51380224 85.42 85.79 94.04 95.31 94.94 97.92 98.53 98.21 98.84 99.06 99.46 99.53 99.72 99.78 99.81 99.80
Layer 17 - layer2.1.conv2 147456 115605504 76.95 82.75 87.63 91.50 91.76 95.59 97.22 96.07 97.32 97.80 98.24 98.24 98.60 99.24 99.66 99.33
Layer 18 - layer2.1.conv3 65536 51380224 84.76 84.71 93.10 93.66 93.23 97.00 98.18 97.35 98.06 98.41 98.96 99.21 99.32 99.55 99.58 99.59
Layer 19 - layer2.2.conv1 65536 51380224 84.30 85.34 92.70 94.61 94.76 97.72 97.91 98.21 98.54 98.98 99.24 99.35 99.50 99.62 99.63 99.77
Layer 20 - layer2.2.conv2 147456 115605504 84.28 85.43 92.99 94.86 94.90 97.52 97.21 98.11 98.19 99.04 99.28 99.37 99.46 99.63 99.59 99.72
Layer 21 - layer2.2.conv3 65536 51380224 82.19 84.21 91.12 93.38 93.53 96.89 97.14 97.59 97.77 98.66 98.96 99.15 99.25 99.49 99.51 99.57
Layer 22 - layer2.3.conv1 65536 51380224 83.37 84.41 90.46 93.26 93.50 96.71 97.89 96.99 98.14 98.36 99.10 99.23 99.33 99.53 99.75 99.60
Layer 23 - layer2.3.conv2 147456 115605504 82.83 84.03 91.44 93.21 93.25 96.83 98.02 96.96 98.45 98.30 98.97 99.06 99.26 99.31 99.81 99.68
Layer 24 - layer2.3.conv3 65536 51380224 82.93 85.65 91.02 94.14 93.56 97.20 97.97 97.04 98.16 98.36 98.88 98.97 99.20 99.32 99.67 99.62
Layer 25 - layer3.0.conv1 131072 102760448 76.63 77.98 85.99 88.85 88.60 94.26 95.07 94.97 96.21 96.59 97.75 98.04 98.30 98.72 99.11 99.06
Layer 26 - layer3.0.conv2 589824 115605504 87.35 88.68 94.39 96.14 96.19 98.51 98.77 98.72 99.11 99.23 99.53 99.59 99.64 99.73 99.80 99.81
Layer 27 - layer3.0.conv3 262144 51380224 81.22 83.22 90.58 93.19 93.05 96.82 97.38 97.32 97.98 98.28 98.88 99.03 99.16 99.39 99.55 99.53
Layer 28 - layer3.0.downsample.0 524288 102760448 89.75 90.99 96.05 97.20 97.16 98.96 99.21 99.20 99.50 99.58 99.78 99.82 99.86 99.91 99.94 99.93
Layer 29 - layer3.1.conv1 262144 51380224 85.88 87.35 93.43 95.36 96.12 98.64 98.77 98.87 99.22 99.33 99.64 99.67 99.72 99.82 99.88 99.84
Layer 30 - layer3.1.conv2 589824 115605504 85.06 86.24 92.74 95.06 95.30 98.09 98.28 98.36 98.75 99.08 99.46 99.48 99.54 99.69 99.76 99.76
Layer 31 - layer3.1.conv3 262144 51380224 84.34 86.79 92.15 94.84 94.90 97.75 98.15 98.11 98.56 98.94 99.30 99.36 99.45 99.65 99.79 99.70
Layer 32 - layer3.2.conv1 262144 51380224 87.51 89.15 94.15 96.77 96.46 98.81 98.83 98.96 99.19 99.44 99.67 99.71 99.74 99.82 99.85 99.89
Layer 33 - layer3.2.conv2 589824 115605504 87.15 88.67 94.09 95.59 96.14 98.86 98.69 98.91 99.21 99.20 99.64 99.72 99.76 99.85 99.84 99.90
Layer 34 - layer3.2.conv3 262144 51380224 84.86 86.90 92.40 94.99 94.99 98.19 98.19 98.42 98.76 98.97 99.42 99.56 99.62 99.76 99.75 99.88
Layer 35 - layer3.3.conv1 262144 51380224 86.62 89.46 94.06 96.08 95.88 98.70 98.71 98.77 99.01 99.27 99.58 99.66 99.69 99.83 99.87 99.87
Layer 36 - layer3.3.conv2 589824 115605504 86.52 87.97 93.56 96.10 96.11 98.70 98.82 98.89 99.19 99.31 99.68 99.73 99.77 99.88 99.87 99.93
Layer 37 - layer3.3.conv3 262144 51380224 84.19 86.81 92.32 94.94 94.91 98.20 98.37 98.43 98.82 99.00 99.51 99.57 99.64 99.81 99.81 99.87
Layer 38 - layer3.4.conv1 262144 51380224 85.85 88.40 93.55 95.49 95.86 98.35 98.44 98.55 98.79 98.96 99.54 99.59 99.60 99.82 99.86 99.87
Layer 39 - layer3.4.conv2 589824 115605504 85.96 87.38 93.27 95.66 95.63 98.41 98.58 98.56 99.19 99.26 99.64 99.69 99.67 99.87 99.90 99.92
Layer 40 - layer3.4.conv3 262144 51380224 83.45 85.76 91.75 94.49 94.35 97.67 98.09 97.99 98.65 98.94 99.49 99.52 99.48 99.77 99.86 99.85
Layer 41 - layer3.5.conv1 262144 51380224 83.33 85.77 91.79 95.09 94.24 97.46 97.89 97.92 98.71 98.90 99.35 99.52 99.58 99.76 99.79 99.83
Layer 42 - layer3.5.conv2 589824 115605504 84.98 86.67 92.48 94.92 95.13 97.88 98.14 98.32 98.91 99.00 99.44 99.58 99.69 99.80 99.83 99.87
Layer 43 - layer3.5.conv3 262144 51380224 79.78 82.23 89.39 93.14 92.76 96.59 97.04 97.30 98.10 98.41 99.03 99.25 99.44 99.61 99.71 99.75
Layer 44 - layer4.0.conv1 524288 102760448 77.83 79.61 87.11 90.32 90.64 95.39 95.84 95.92 97.17 97.35 98.36 98.60 98.83 99.20 99.37 99.42
Layer 45 - layer4.0.conv2 2359296 115605504 86.18 88.00 93.53 95.66 95.78 98.31 98.47 98.55 99.08 99.16 99.54 99.63 99.69 99.81 99.85 99.86
Layer 46 - layer4.0.conv3 1048576 51380224 78.43 80.48 87.85 91.14 91.27 96.00 96.40 96.47 97.53 97.92 98.81 99.00 99.15 99.45 99.57 99.61
Layer 47 - layer4.0.downsample.0 2097152 102760448 88.49 89.98 95.03 96.79 96.90 98.91 99.06 99.11 99.45 99.51 99.77 99.82 99.85 99.92 99.94 99.94
Layer 48 - layer4.1.conv1 1048576 51380224 82.07 84.02 90.34 93.69 93.72 97.15 97.56 97.76 98.45 98.75 99.27 99.36 99.54 99.67 99.76 99.80
Layer 49 - layer4.1.conv2 2359296 115605504 83.42 85.23 91.16 93.98 93.93 97.26 97.58 97.71 98.36 98.67 99.25 99.34 99.50 99.68 99.76 99.80
Layer 50 - layer4.1.conv3 1048576 51380224 78.08 79.96 86.66 90.48 90.22 95.22 95.76 95.89 96.88 97.65 98.70 98.85 99.13 99.45 99.58 99.66
Layer 51 - layer4.2.conv1 1048576 51380224 76.34 77.93 84.98 87.57 88.47 93.90 93.87 94.16 95.55 95.91 97.66 97.97 98.15 98.88 99.08 99.22
Layer 52 - layer4.2.conv2 2359296 115605504 73.57 74.97 82.32 84.37 86.01 91.92 91.66 92.22 94.02 94.16 96.65 97.13 97.29 98.44 98.74 99.00
Layer 53 - layer4.2.conv3 1048576 51380224 68.78 70.38 78.11 80.29 81.73 89.64 89.43 89.65 91.40 92.65 96.02 96.72 96.93 98.47 98.83 99.15
Layer 54 - fc 2048000 2048000 50.65 52.46 60.48 64.50 65.12 75.20 75.73 75.80 78.57 80.69 85.96 87.26 88.03 91.11 92.15 92.87"""

In [None]:
# Same list of kernel names as in the DSR section above, rebuilt here so that
# this section can be run on its own: `conv1`, groups res2..res5, `fc1000`.
resnet_layers = ['conv1/kernel:0']
for _group, _blocks in ((2, 'abc'), (3, 'abcd'), (4, 'abcdef'), (5, 'abc')):
  for _block in _blocks:
    _branches = ['branch2a', 'branch2b', 'branch2c']
    if _block == 'a':
      # Only block `a` of a group has a `branch1` kernel; it is listed last.
      _branches.append('branch1')
    resnet_layers.extend(
        'res%d%s_%s/kernel:0' % (_group, _block, _branch)
        for _branch in _branches)
resnet_layers.append('fc1000/kernel:0')

In [None]:
from collections import defaultdict


def parse_str_sparsities(table):
  """Parses the layer-wise sparsity table into one sparsity list per column.

  Every non-empty row has the form `<label> - <name> <x> <y> <pct> <pct> ...`.
  The three fields after the dash (name and two numbers) are skipped and the
  remaining fields are read as percentages.

  A row named `Overall` may appear anywhere in the table. If present, its
  percentages are used as the dictionary keys. Otherwise the 0-based column
  index is used as the key, so tables without that row still parse.

  Args:
    table: multi-line string with one row per layer.

  Returns:
    A `defaultdict(list)` mapping column key -> list of per-layer sparsities
    in [0, 1], in row order.
  """
  overall_sparsities = None
  layer_rows = []
  for row in table.strip().split('\n'):
    if not row.strip():
      continue
    fields = row.split('-', 1)[1].split()
    # A real list is needed here: `map` cannot be indexed on Python 3.
    percentages = [float(field) for field in fields[3:]]
    if fields[0] == 'Overall':
      overall_sparsities = percentages
    else:
      layer_rows.append(percentages)

  parsed = defaultdict(list)
  for percentages in layer_rows:
    for column, percentage in enumerate(percentages):
      key = column if overall_sparsities is None else overall_sparsities[column]
      # Sparsities are given in percent, so divide by 100 to get [0, 1].
      parsed[key].append(percentage / 100.)
  return parsed


str_sparsities_parsed = parse_str_sparsities(str_sparsities)

In [None]:
# One report per column of the parsed table: print the column key, then the
# stats obtained with that column's per-layer sparsities.
for column_key, layer_sparsities in str_sparsities_parsed.items():
  print(column_key)
  dsr_map = dict(zip(resnet_layers, layer_sparsities))
  print_stats(masked_layers, 0., 'random', custom_sparsities=dsr_map, width=1)