In [1]:
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import pytorch_lightning as pl
import random
import dotenv
import omegaconf
import hydra
import logging
import wandb
from datetime import date
import pathlib
from typing import Dict, Any
from copy import deepcopy

from rigl_torch.models.model_factory import ModelFactory
from rigl_torch.rigl_scheduler import RigLScheduler
from rigl_torch.rigl_constant_fan import RigLConstFanScheduler
from rigl_torch.datasets import get_dataloaders
from rigl_torch.optim import (
    get_optimizer,
    get_lr_scheduler,
)
from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.utils.rigl_utils import get_T_end, get_fan_in_after_ablation, get_conv_idx_from_flat_idx


In [2]:
dataset = "cifar10"
model="wide_resnet22"
with hydra.initialize(config_path="../configs"):
    cfg = hydra.compose(
        config_name="config.yaml", 
        overrides=[
            f"dataset={dataset}",
            f"model={model}",
            "compute.distributed=False",
            "rigl.dense_allocation=0.01",
            "rigl.const_fan_in=True",
            "rigl.filter_ablation_threshold=0.01",
            "rigl.dynamic_ablation=True", 
            "rigl.static_ablation=False",
            "rigl.min_salient_weights_per_neuron=1",
            "rigl.delta=2",
            "rigl.grad_accumulation_n=1",
        ]
    )
cfg

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize(config_path="../configs"):


{'dataset': {'name': 'cifar10', 'normalize': False, 'num_classes': 10, 'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], 'train_len': 50000}, 'model': {'name': 'wide_resnet22'}, 'experiment': {'comment': 'dense_alloc-${rigl.dense_allocation}_weight_per_neuron-${rigl.min_salient_weights_per_neuron}', 'name': '${model.name}_${dataset.name}_${experiment.comment}', 'resume_from_checkpoint': False, 'run_id': None}, 'paths': {'base': '${oc.env:BASE_PATH}', 'data_folder': '${paths.base}/data', 'artifacts': '${paths.base}/artifacts', 'logs': '${paths.base}/logs', 'checkpoints': '${paths.artifacts}/checkpoints'}, 'rigl': {'dense_allocation': 0.01, 'delta': 2, 'grad_accumulation_n': 1, 'alpha': 0.3, 'static_topo': 0, 'const_fan_in': True, 'sparsity_distribution': 'erk', 'erk_power_scale': 1.0, 'use_t_end': True, 'static_ablation': False, 'dynamic_ablation': True, 'filter_ablation_threshold': 0.01, 'use_sparse_initialization': False, 'min_salie

In [1]:
# net = ModelFactory.load_model("wide_resnet22", "cifar10")
# net = ModelFactory.load_model("resnet50", "imagenet")
device = torch.device("cuda:0")
train_loader, test_loader = get_dataloaders(cfg)
model = ModelFactory.load_model(
        model=cfg.model.name, dataset=cfg.dataset.name
    )
model.to(device)
optimizer = get_optimizer(cfg, model, state_dict=None)
scheduler = get_lr_scheduler(cfg, optimizer, state_dict=None)
T_end = get_T_end(cfg, train_loader)
if cfg.rigl.const_fan_in:
    rigl_scheduler = RigLConstFanScheduler
else:
    rigl_scheduler = RigLScheduler
pruner = rigl_scheduler(
    model,
    optimizer,
    dense_allocation=cfg.rigl.dense_allocation,
    alpha=cfg.rigl.alpha,
    delta=cfg.rigl.delta,
    static_topo=cfg.rigl.static_topo,
    T_end=T_end,
    ignore_linear_layers=False,
    grad_accumulation_n=cfg.rigl.grad_accumulation_n,
    sparsity_distribution=cfg.rigl.sparsity_distribution,
    erk_power_scale=cfg.rigl.erk_power_scale,
    state_dict=None,
    filter_ablation_threshold=cfg.rigl.static_filter_ablation_threshold,
    static_ablation=cfg.rigl.static_ablation,
    dynamic_ablation=cfg.rigl.dynamic_ablation,
    min_salient_weights_per_neuron=cfg.rigl.min_salient_weights_per_neuron,
    )

NameError: name 'torch' is not defined

In [8]:
w = pruner.W[0]
score_grow_lifted = torch.abs(w)
n_fan_in = -1
idx_to_grow = torch.topk(
                    score_grow_lifted.flatten(),
                    k=n_fan_in,
                ).indices
idx_to_grow

RuntimeError: selected index k out of range

In [7]:
t = torch.Tensor([[2,3],[2,3]])
torch.where(
    t==2,
    torch.ones_like(t, dtype=torch.bool),
    torch.zeros_like(t, dtype=torch.bool)
)

tensor([[ True, False],
        [ True, False]])

In [4]:
for batch_idx, (data,target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    logits = model(data)
    loss = F.cross_entropy(
        logits,
        target,
        label_smoothing=cfg.training.label_smoothing,
    )
    loss.backward()
    optimizer.step()
    if batch_idx > 200:
        break
    pruner()
    optimizer.zero_grad()

In [7]:
drop_fraction = pruner.cosine_annealing()
# if distributed these values will be populated
is_dist = dist.is_initialized()
world_size = dist.get_world_size() if is_dist else None

pruner.dynamically_ablated_neuron_idx = []
for idx, w in enumerate(pruner.W):
    # if sparsity is 0%, skip
    if pruner.S[idx] <= 0:
        continue

    # calculate raw scores
    score_drop = torch.abs(w)
    _max_score_drop = score_drop.max().item()

    # Set ablated filter drop scores to min of score_grow to avoid
    # pruning already inactive weights
    # TODO: Remove inital ablated filtering.
    score_drop[
        : pruner.static_ablated_filters[idx]
    ] = score_drop.min().item()

    score_grow = torch.abs(pruner.backward_hook_objects[idx].dense_grad)

    # Set ablated filter scores to min of score_grow to avoid regrowing
    score_grow[
        : pruner.static_ablated_filters[idx]
    ] = score_grow.min().item()

    current_mask = pruner.backward_masks[idx]
    break


In [8]:
w.shape

torch.Size([16, 3, 3, 3])

In [69]:
sorted_values, sorted_idx = torch.abs(w).flatten().sort(descending=True)
max_idx = sorted_idx[0]
max_idx

tensor(326, device='cuda:0')

In [12]:
max_value = sorted_values[0]
max_value

tensor(0.3056, device='cuda:0', grad_fn=<SelectBackward0>)

In [84]:
@torch.no_grad()
def _get_neurons_to_ablate(
    pruner, 
    score_drop: torch.Tensor,
    score_grow,
    n_keep,
    n_prune,
    n_total,
):
    if pruner.dynamic_ablation:
        neurons_to_ablate = []
        saliency_mask = torch.zeros(
            size=(score_drop.numel(),),
            dtype=torch.bool,
            device=score_drop.device,
        )
        _, keep_idx = score_drop.flatten().sort(descending=True)
        print(keep_idx[:n_keep])
        saliency_mask[keep_idx[:n_keep]] = True
        # print(f"Keep_idx from drop: {[i for i in saliency_mask if i.any()]}")

        _, grow_idx = score_grow.flatten().sort(descending=True)
        saliency_mask[grow_idx[:n_prune]] = True

        saliency_mask = saliency_mask.reshape(shape=score_drop.shape)
        for neuron_idx, neuron in enumerate(saliency_mask):
            if neuron.sum() < pruner.min_salient_weights_per_neuron:
                print(f"ablating neuron {neuron_idx}")
                neurons_to_ablate.append(neuron_idx)
        return neurons_to_ablate, saliency_mask
        
n_total = pruner.N[idx]
n_ones = torch.sum(current_mask).item()
n_prune = int(n_ones * drop_fraction)
n_keep = int(n_ones - n_prune)
n_non_zero_weights = torch.count_nonzero(score_drop).item()
if n_non_zero_weights < n_keep:
    # Then we don't have enough non-zero weights to keep. We keep
    # ALL non-zero weights in this scenario and readjust our keep /
    # prune amounts to suit
    n_keep = n_non_zero_weights
    n_prune = n_ones - n_keep

# Get neurons to ablate
neurons_to_ablate, saliency_mask = _get_neurons_to_ablate(
    pruner,
    score_drop=score_drop,
    score_grow=score_grow,
    n_keep=n_keep,
    n_prune=n_prune,
    n_total=n_total,
)


tensor([326, 328, 424,  12, 244, 199, 203, 427, 342, 247, 215,  17,  14, 350,
        192,  26, 420, 260, 411, 409, 333, 245, 202, 251, 250, 248, 256, 254,
        259, 257, 243, 253,  19, 193, 194, 207, 246,  20,  18, 197, 249, 196,
        198, 191, 190,  22,  25, 414, 349,  24, 415, 429, 347, 345, 200,  13,
         16], device='cuda:0')
ablating neuron 2
ablating neuron 3
ablating neuron 4
ablating neuron 5
ablating neuron 6
ablating neuron 8
ablating neuron 10
ablating neuron 13
ablating neuron 14


In [85]:
neurons_to_ablate

[2, 3, 4, 5, 6, 8, 10, 13, 14]

In [77]:
max_idx

tensor(326, device='cuda:0')

In [66]:
for f in saliency_mask:
    print(f.sum())

tensor(12, device='cuda:0')
tensor(7, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(14, device='cuda:0')
tensor(0, device='cuda:0')
tensor(15, device='cuda:0')
tensor(0, device='cuda:0')
tensor(16, device='cuda:0')
tensor(8, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(8, device='cuda:0')


In [71]:
neurons_to_ablate

[2, 3, 4, 5, 6, 8, 10, 13, 14]

In [72]:
get_conv_idx_from_flat_idx(max_idx.item(), w.shape)

(12, 0, 0, 2)

In [86]:
pruner.dynamically_ablated_neuron_idx.append(neurons_to_ablate)
# print(f"neurons to ablate = {neurons_to_ablate}")
# print(f"len neurons to ablate = {len(neurons_to_ablate)}")
n_fan_in = get_fan_in_after_ablation(
    weight_tensor=w,
    num_neurons_to_ablate=len(neurons_to_ablate),
    sparsity=pruner.S[idx],
)

# create drop mask
drop_mask = pruner._get_drop_mask(
    score_drop,
    n_keep,
    neurons_to_ablate=neurons_to_ablate,
    n_fan_in=n_fan_in,
)

# create growth mask per filter
grow_mask = pruner._get_grow_mask(
    score_grow,
    drop_mask,
    n_fan_in,
    neurons_to_ablate,
)

# get new weights
new_weights = pruner._get_new_weights(w, current_mask, grow_mask)
w.data = new_weights

combined_mask = grow_mask + drop_mask
current_mask.data = combined_mask

pruner.reset_momentum()
pruner.apply_mask_to_weights()
pruner.apply_mask_to_gradients()
pruner._verify_neuron_ablation()

In [89]:
torch.abs(w).max()

tensor(0.3056, device='cuda:0', grad_fn=<MaxBackward1>)

In [20]:
from rigl_torch.utils.rigl_utils import get_W

W = get_W(model=model)
type(W[0])

torch.nn.parameter.Parameter

In [25]:
n_fan_in = 3

_, idx = w.topk(k=n_fan_in, largest=True, dim=0, sorted=False)
idx

tensor([[[[13,  9,  8],
          [ 3, 12, 10],
          [11, 11,  9]],

         [[ 0,  1,  0],
          [ 1,  1,  1],
          [ 0,  0,  0]],

         [[ 0,  2,  1],
          [ 1,  2,  0],
          [ 0,  1, 12]]],


        [[[ 0,  1, 13],
          [ 0, 15, 13],
          [ 0, 14, 11]],

         [[ 3,  7,  1],
          [ 3,  4,  7],
          [ 1,  1, 10]],

         [[ 4,  3,  0],
          [ 6,  3,  2],
          [ 2,  4, 13]]],


        [[[ 3,  3,  0],
          [ 1,  0,  1],
          [ 1,  1,  0]],

         [[ 6,  9,  7],
          [ 4,  6,  2],
          [ 2,  3,  2]],

         [[ 7,  5,  2],
          [ 7,  1,  3],
          [ 5,  6,  2]]]], device='cuda:0')

In [5]:
w = W[0]
w.shape
n_keep = int(w.numel()*0.1)
n_keep

43

In [164]:
item = 0.3347453474998474

In [13]:
v, i = torch.abs(w).flatten().sort(descending=True)
i
torch.abs(w).flatten()[i][:40]

tensor([0.3347, 0.3173, 0.3022, 0.2750, 0.2702, 0.2687, 0.2682, 0.2674, 0.1919,
        0.1640, 0.1605, 0.1596, 0.1591, 0.1573, 0.1544, 0.1524, 0.1522, 0.1480,
        0.1461, 0.1427, 0.1416, 0.1350, 0.1318, 0.1273, 0.1163, 0.1135, 0.1129,
        0.1128, 0.1062, 0.1062, 0.1047, 0.1025, 0.0999, 0.0956, 0.0897, 0.0894,
        0.0883, 0.0869, 0.0861, 0.0854], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [31]:
s_w = (mask.to(device="cuda") * w)

In [35]:
s_w[]

tensor(0.3347, device='cuda:0', grad_fn=<MaxBackward1>)

In [27]:
mask = torch.zeros(size=(w.numel(),), dtype=torch.bool)
mask[i[:40]] = True
mask = mask.reshape(w.shape)
# mask = mask.reshape(-1)[i] = True
# mask.reshape(w.shape)

In [12]:
_, keep_mask_idx = torch.abs(w).flatten().topk(k=w.numel(), dim=-1, largest=True, sorted=False)
mask = (keep_mask_idx == 0).reshape(w.shape)

for idx, f in enumerate(mask):
    if f.any():
        print(idx)
# keep_mask= torch.zeros(size=w.shape, dtype=torch.bool).flatten()
# keep_mask[keep_mask_idx] = True
# keep_mask.sum()

3


In [21]:
values, idx = w.topk(k=-1,
                     dim=0)
idx.shape

RuntimeError: selected index k out of range

In [13]:
w.shape

torch.Size([16, 3, 3, 3])

In [18]:
W2 = _get_W(model=model)
W2 == W

True

In [13]:
len(pruner.W)

54

In [19]:
for i in model._modules.items():
    print(i)

('conv1', Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))
('bn1', BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
('relu', ReLU(inplace=True))
('maxpool', MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
('layer1', Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): Ba

In [17]:
for better_way, pruner_way in zip(found_types, pruner.W):
    print(torch.eq(better_way.weight, pruner_way).all())

tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(Tru

In [18]:
pruner.__str__()

'RigLScheduler(\nlayers=54,\nnonzero_params=[384/9408, 576/4096, 640/36864, 1536/16384, 1536/16384, 1536/16384, 640/36864, 1536/16384, 1536/16384, 640/36864, 1536/16384, 1792/32768, 1200/147456, 3072/65536, 3584/131072, 3072/65536, 1200/147456, 3072/65536, 3072/65536, 1200/147456, 3072/65536, 3072/65536, 1200/147456, 3072/65536, 3584/131072, 2500/589824, 6000/262144, 7000/524288, 6144/262144, 2500/589824, 6000/262144, 6144/262144, 2500/589824, 6000/262144, 6144/262144, 2500/589824, 6000/262144, 6144/262144, 2500/589824, 6000/262144, 6144/262144, 2500/589824, 6000/262144, 7168/524288, 5000/2359296, 12000/1048576, 14700/2097152, 12288/1048576, 5000/2359296, 12000/1048576, 12288/1048576, 5000/2359296, 12000/1048576, 14000/2048000],\nnonzero_percentages=[4.08%, 14.06%, 1.74%, 9.38%, 9.38%, 9.38%, 1.74%, 9.38%, 9.38%, 1.74%, 9.38%, 5.47%, 0.81%, 4.69%, 2.73%, 4.69%, 0.81%, 4.69%, 4.69%, 0.81%, 4.69%, 4.69%, 0.81%, 4.69%, 2.73%, 0.42%, 2.29%, 1.34%, 2.34%, 0.42%, 2.29%, 2.34%, 0.42%, 2.29%, 

In [5]:
pruner.get_global_sparsity_from_masks()

tensor(0.9900, device='cuda:0')

In [6]:
pruner.__str__()

'RigLScheduler(\nlayers=54,\nnonzero_params=[147/9408, 576/4096, 649/36864, 1560/16384, 1560/16384, 1536/16384, 649/36864, 1560/16384, 1536/16384, 649/36864, 1560/16384, 1792/32768, 1270/147456, 3120/65536, 3710/131072, 3072/65536, 1270/147456, 3120/65536, 3072/65536, 1270/147456, 3120/65536, 3072/65536, 1270/147456, 3120/65536, 3640/131072, 2520/589824, 6240/262144, 7490/524288, 6240/262144, 2520/589824, 6240/262144, 6240/262144, 2520/589824, 6240/262144, 6240/262144, 2520/589824, 6240/262144, 6240/262144, 2520/589824, 6240/262144, 6240/262144, 2520/589824, 6240/262144, 7420/524288, 5020/2359296, 12480/1048576, 14980/2097152, 12480/1048576, 5020/2359296, 12480/1048576, 12480/1048576, 5020/2359296, 12480/1048576, 14840/2048000],\nnonzero_percentages=[1.56%, 14.06%, 1.76%, 9.52%, 9.52%, 9.38%, 1.76%, 9.52%, 9.38%, 1.76%, 9.52%, 5.47%, 0.86%, 4.76%, 2.83%, 4.69%, 0.86%, 4.76%, 4.69%, 0.86%, 4.76%, 4.69%, 0.86%, 4.76%, 2.78%, 0.43%, 2.38%, 1.43%, 2.38%, 0.43%, 2.38%, 2.38%, 0.43%, 2.38%, 

In [7]:
pruner.W[0].shape

torch.Size([64, 3, 7, 7])

In [8]:
pruner.inital_ablated_filters

[63,
 0,
 53,
 196,
 196,
 0,
 53,
 196,
 0,
 53,
 196,
 0,
 118,
 452,
 442,
 0,
 118,
 452,
 0,
 118,
 452,
 0,
 118,
 452,
 116,
 246,
 964,
 954,
 16,
 246,
 964,
 16,
 246,
 964,
 16,
 246,
 964,
 16,
 246,
 964,
 16,
 246,
 964,
 372,
 502,
 1988,
 1978,
 272,
 502,
 1988,
 272,
 502,
 1988,
 860]

In [9]:
old_neurons_per_layer = pruner.inital_ablated_filters
old_neurons_per_layer

[63,
 0,
 53,
 196,
 196,
 0,
 53,
 196,
 0,
 53,
 196,
 0,
 118,
 452,
 442,
 0,
 118,
 452,
 0,
 118,
 452,
 0,
 118,
 452,
 116,
 246,
 964,
 954,
 16,
 246,
 964,
 16,
 246,
 964,
 16,
 246,
 964,
 16,
 246,
 964,
 16,
 246,
 964,
 372,
 502,
 1988,
 1978,
 272,
 502,
 1988,
 272,
 502,
 1988,
 860]

In [10]:
pruner.backward_masks[0][-1].sum() / pruner.backward_masks[0].numel()

tensor(0.0156, device='cuda:0')

In [11]:
1-pruner.S[0]

0.0419701390045516

In [12]:
from rigl_torch.utils.rigl_utils import get_fan_in_tensor


for idx, (m, n) in enumerate(list(zip(pruner.backward_masks, pruner.inital_ablated_filters))):
    fan_in_tens = get_fan_in_tensor(m)
    # print(fan_in_tens)
    if m.shape[0] < n:
        print(idx)

In [None]:
pruner.inital_ablated_filters[53]

1908

In [None]:
pruner.W[53].shape

torch.Size([1000, 2048])

In [None]:
pruner.W[-1].shape

torch.Size([1000, 2048])

In [24]:
pruner.S

[0.9580298609954484,
 0.8452836889704899,
 0.9822803541214066,
 0.9041948997086495,
 0.9041948997086495,
 0.9041948997086495,
 0.9822803541214066,
 0.9041948997086495,
 0.9041948997086495,
 0.9822803541214066,
 0.9041948997086495,
 0.9425764460986625,
 0.9913385313056129,
 0.9522462155380069,
 0.9713626058911724,
 0.9522462155380069,
 0.9913385313056129,
 0.9522462155380069,
 0.9522462155380069,
 0.9913385313056129,
 0.9522462155380069,
 0.9522462155380069,
 0.9913385313056129,
 0.9522462155380069,
 0.9713626058911724,
 0.9957188542140338,
 0.9761602991899241,
 0.9856998986560465,
 0.9761602991899241,
 0.9957188542140338,
 0.9761602991899241,
 0.9761602991899241,
 0.9957188542140338,
 0.9761602991899241,
 0.9761602991899241,
 0.9957188542140338,
 0.9761602991899241,
 0.9761602991899241,
 0.9957188542140338,
 0.9761602991899241,
 0.9761602991899241,
 0.9957188542140338,
 0.9761602991899241,
 0.9856998986560465,
 0.9978718242473238,
 0.9880894474501921,
 0.9928545982556383,
 0.9880894474

In [36]:
from rigl_torch.utils.rigl_utils import calculate_fan_in_and_fan_out


calculate_fan_in_and_fan_out(pruner.W[idx])

(1152, 1152)

In [35]:
pruner.inital_ablated_filters[idx]

62

In [39]:
pruner.W[idx].shape

torch.Size([128, 128, 3, 3])

In [34]:
idx = -2
get_fan_in_after_ablation(pruner.W[idx], num_neurons_to_ablate=pruner.inital_ablated_filters[idx], sparsity=pruner.S[idx])

13

In [44]:
pruner.inital_ablated_filters

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 62, 0, 62, 62, 62, 62, 0]

In [41]:
pruner.backward_masks[idx].sum()

tensor(1664, device='cuda:0')

In [53]:
pruner.backward_masks[idx][1].any()

tensor(True, device='cuda:0')

In [42]:
13 * (128-62)

858

In [40]:
13 * (128-62) / pruner.W[idx].numel()

0.005818684895833333

In [6]:
def get_global_sparsity_from_masks(pruner) -> float:
    total_els = 0
    total_non_zero_els = 0
    for w, m in list(zip(pruner.W, pruner.backward_masks)):
        if m is None:
            total_non_zero_els += w.numel()
            total_els += w.numel()
        else:
            total_non_zero_els += m.sum()
            total_els += w.numel()
    return 1 - (total_non_zero_els / total_els)

get_global_sparsity_from_masks(pruner)

tensor(0.9904, device='cuda:0')

In [7]:
total_el = 0
non_zero_els = 0
for w, m,s in list(zip(pruner.W, pruner.backward_masks, pruner.S)):
    total_el += m.numel()
    non_zero_els += m.sum()
1-(non_zero_els / total_el)

tensor(0.9904, device='cuda:0')

In [8]:
pruner.get_global_sparsity_from_masks()

tensor(0.9904, device='cuda:0')

In [8]:
pruner.inital_ablated_filters

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 62, 0, 62, 62, 62, 62, 0]

In [9]:
total_el = 0 
non_zero_el = 0
for m in pruner.backward_masks:
    
    

TypeError: expected Tensor as element 0 in argument 0, but got list

In [7]:
pruner.get_global_sparsity_from_masks()

tensor(0.0151, device='cuda:0')

In [5]:
pruner.__str__()

'RigLScheduler(\nlayers=23,\nnonzero_params=[85/432, 183/4608, 237/9216, 170/512, 237/9216, 237/9216, 237/9216, 237/9216, 345/18432, 454/36864, 332/2048, 454/36864, 454/36864, 454/36864, 454/36864, 670/73728, 886/147456, 656/8192, 886/147456, 886/147456, 886/147456, 886/147456, 467/1280],\nnonzero_percentages=[19.68%, 3.97%, 2.57%, 33.20%, 2.57%, 2.57%, 2.57%, 2.57%, 1.87%, 1.23%, 16.21%, 1.23%, 1.23%, 1.23%, 1.23%, 0.91%, 0.60%, 8.01%, 0.60%, 0.60%, 0.60%, 0.60%, 36.48%],\ntotal_nonzero_params=10793/1076912 (1.00%),\ntotal_CONV_nonzero_params=10326/1075632 (0.96%),\nstep=0,\nnum_rigl_steps=0,\nignoring_linear_layers=False,\nsparsity_distribution=erk,\nITOP rate=0.0100,\n)'

In [6]:
def get_global_sparsity_from_masks(pruner) -> float:
    total_els = [m.numel() for m in pruner.backward_masks]
    print(total_els)
    non_zero_els = [m.sum().item() for m in pruner.backward_masks]
    print(non_zero_els)
    print(sum(non_zero_els))
    print(sum(total_els))
    return sum(non_zero_els) / sum(total_els)

get_global_sparsity_from_masks(pruner)

#todo what?

[432, 4608, 9216, 512, 9216, 9216, 9216, 9216, 18432, 36864, 2048, 36864, 36864, 36864, 36864, 73728, 147456, 8192, 147456, 147456, 147456, 147456, 1280]
[85, 183, 237, 170, 237, 237, 237, 237, 345, 454, 332, 454, 454, 454, 454, 670, 886, 656, 886, 886, 886, 886, 467]
10793
1076912


0.010022174513795

In [4]:
# from rigl_torch.utils.rigl_utils import calculate_fan_in_and_fan_out
# @torch.no_grad()
# def random_sparsify(pruner) -> None:
#     """Randomly sparsifies model to desired sparsity distribution with
#     constant fan in.
#     """
#     is_dist: bool = dist.is_initialized()
#     pruner.backward_masks = []
#     for idx, (w, num_neurons_to_ablate) in enumerate(
#         list(zip(pruner.W, pruner.ablated_filters))
#     ):
#         # if sparsity is 0%, skip
#         if pruner.S[idx] <= 0:
#             pruner.backward_masks.append(None)
#             continue

#         dense_fan_in, _ = calculate_fan_in_and_fan_out(module=w)
#         fan_in = get_fan_in_after_ablation(
#             weight_tensor=w,
#             num_neurons_to_ablate=num_neurons_to_ablate,
#             sparsity=pruner.S[idx],
#         )
#         print(fan_in)
#         print(dense_fan_in)
#         # Number of connections to drop per filter
#         s = dense_fan_in - fan_in
#         print(f"s is {s}")
#         perm = torch.concat(
#             [
#                 torch.randperm(fan_in).reshape(1, -1)
#                 for _ in range(w.shape[0])
#             ]
#         )
#         # Generate random perm of indices to mask per filter / neuron
#         perm = perm[
#             :, :s
#         ]  # Drop s elements from n to achieve desired sparsity
#         print(perm)
#         print(f"perm shape: {perm.shape}")
#         mask = torch.concat(
#             [torch.ones(dense_fan_in).reshape(1, -1) for _ in range(w.shape[0])]
#         )
#         print(f"mask shape: {mask.shape}")
#         for filter_idx in range(mask.shape[0]):  # TODO: vectorize?
#             mask[filter_idx][perm[filter_idx]] = 0
#         mask = mask.reshape(w.shape).to(device=w.device)
#         # Ablate top n neurons according to filter sparsity criterion
#         mask[num_neurons_to_ablate:] = False

#         if is_dist:
#             dist.broadcast(mask, 0)
#         mask = mask.bool()
#         w *= mask
#         pruner.backward_masks.append(mask)
#     return pruner

# const_fan_pruner = random_sparsify(pruner)

In [5]:
pruner.inital_ablated_filters   # TODO: Make sure this matches below!

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 62, 0, 62, 62, 62, 62, 0]

In [7]:
def _update_current_filter_ablation(pruner) -> None:
    def get_num_ablated_filters(mask) -> int:
        if mask is None:
            return 0
        else:
            return torch.sum(
                torch.stack([~filter.any() for filter in mask])
            ).item()

    ablated_filters = [
        get_num_ablated_filters(filter) for filter in pruner.backward_masks
    ]
    return ablated_filters

_update_current_filter_ablation(pruner)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 2, 0, 0, 0, 0, 0]

In [14]:
pruner.S[-6]

0.9199324398812323

In [13]:
pruner.backward_masks[-6]

for filter in pruner.backward_masks[-6]:
    if  ~filter.any():
        print(filter.shape)
        print(filter.any())

torch.Size([64, 1, 1])
tensor(False, device='cuda:0')
torch.Size([64, 1, 1])
tensor(False, device='cuda:0')


In [8]:
non_zero_filters = 0
for f in mask:
    if f.any():
        non_zero_filters+=1
non_zero_filters

66

In [28]:
pruner.backward_masks[21][0].sum()

tensor(8, device='cuda:0')

In [29]:
pruner.backward_masks[21].sum() / pruner.backward_masks[21].numel()

tensor(0.0060, device='cuda:0')

In [30]:
1-pruner.S[21]

0.0060073598943634066

In [32]:
pruner.backward_masks[21].shape

torch.Size([128, 128, 3, 3])

In [34]:
8*128 / pruner.backward_masks[21].numel()

0.006944444444444444

In [35]:
pruner.backward_masks[21].numel()

147456

In [36]:
8*128 / 147456

0.006944444444444444

In [38]:
18*62 / 147456

0.007568359375

In [40]:
456/76032

0.005997474747474747

In [51]:
from rigl_torch.utils.rigl_utils import calculate_fan_in_and_fan_out
import math
idx=21
weight_tensor = pruner.W[idx]
num_neurons_to_ablate = pruner.ablated_filters[idx]
sparsity = pruner.S[idx]
active_neurons = weight_tensor.shape[0] - num_neurons_to_ablate
print(active_neurons)
remaining_non_zero_elements = math.floor(weight_tensor.numel() * (1 - sparsity))
print(remaining_non_zero_elements)
remaining_non_zero_elements // active_neurons

66
885


13

In [22]:
pruner.ablated_filters[21]

62

In [21]:
1-(6.9*62 / weight_tensor.numel())

0.9970987955729167

In [11]:
sparsity

0.9939926401056366

In [26]:
weight_tensor.shape

torch.Size([128, 128, 3, 3])

In [29]:
6*66 / weight_tensor.numel()

0.002685546875

In [31]:
1-sparsity

0.0060073598943634066

In [None]:
def get_filter_sparsity(mask): 
    print()

In [11]:
(dense_fan_in * (pruner.W[21].shape[0]-62) * (1-pruner.S[21])) / pruner.W[21][:].numel()

0.0030975449455311315

In [12]:
pruner.W[21][62:].shape[0]

66

In [13]:
(dense_fan_in * (pruner.W[21].shape[0]-62) * (1-pruner.S[21])) 

456.75158748823856

In [14]:
get_fan_in_after_ablation(
    weight_tensor = pruner.W[21],
    num_neurons_to_ablate=62,
    sparsity=pruner.S[21]
)

6

In [52]:
remaining_els * 

71424

In [37]:
pruner.backward_masks[21].sum() / pruner.backward_masks[21].numel()

tensor(0.0060, device='cuda:0')

In [34]:
1-pruner.S[21]

0.0060073598943634066

In [8]:
import math
for idx, (w,s) in enumerate(list(zip(pruner.W, pruner.S))):
    if s is None: 
        continue
    unadjusted_fan_in = w.shape[1]*math.prod(w.shape[2:])
    sparse_fan_in = int( (1-s) * unadjusted_fan_in)
    out_channels = w.shape[1]
    receptive_field_size=9
    unadjusted_filter_sparsity = sparse_fan_in / (out_channels * receptive_field_size)
    if unadjusted_filter_sparsity < 0.01:
        print(unadjusted_filter_sparsity)
        print(idx)
    

0.008680555555555556
15
0.005208333333333333
16
0.008680555555555556
17
0.005208333333333333
18
0.005208333333333333
19
0.005208333333333333
20
0.005208333333333333
21


In [4]:
pruner.ablated_filters

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [32]:
from rigl_torch.utils.rigl_utils import calculate_fan_in_and_fan_out
idx=0

w = pruner.W[idx]
fan_in, fan_out = calculate_fan_in_and_fan_out(w)
s = int(fan_in * pruner.S[idx])
perm = torch.concat(
    [
        torch.randperm(fan_in).reshape(1, -1)
        for i in range(w.shape[0])
    ]
)
perm.shape

torch.Size([16, 27])

In [36]:
pruner.S[idx]

0.8043404410996132

In [34]:
(27-21)/27

0.2222222222222222

In [38]:
perm = perm[
:, :s
]  # Drop s elements from n to achieve desired sparsity
perm.shape

torch.Size([16, 21])

In [20]:
import math

math.prod(w.shape[1:])

144

In [18]:
w.shape

torch.Size([32, 16, 3, 3])

In [7]:
pruner.S

[0.8043404410996132,
 0.9603789393226717,
 0.9743196828943242,
 0.6698244943555973,
 0.9743196828943242,
 0.9743196828943242,
 0.9743196828943242,
 0.9743196828943242,
 0.9812900546801505,
 0.9877101339565695,
 0.8382140022342427,
 0.9877101339565695,
 0.9877101339565695,
 0.9877101339565695,
 0.9877101339565695,
 0.9909201735947789,
 0.9939926401056366,
 0.9199324398812323,
 0.9939926401056366,
 0.9939926401056366,
 0.9939926401056366,
 0.9939926401056366,
 0.6354862417685794]

In [41]:
conv1 = model.get_submodule("conv1")
conv1

Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [44]:
torch.nn.init._calculate_fan_in_and_fan_out(conv1.weight)

(27, 144)

In [45]:
144/16

9.0

In [5]:
for m in pruner.backward_masks:
    if m is None:
        continue
    else:
        break
        print(m.shape)

In [6]:
pruner.S[1]

0.9603789393226717

In [7]:
pruner.backward_masks[-2].shape

torch.Size([128, 128, 3, 3])

In [8]:
m = pruner.backward_masks[20]

In [9]:
pruner.S[20]

0.9939926401056366

In [10]:
m.shape

torch.Size([128, 128, 3, 3])

In [11]:
len(pruner.backward_masks)

23

In [12]:
m.shape

torch.Size([128, 128, 3, 3])

In [13]:
pruner.S[20]

0.9939926401056366

In [14]:
1 - ( m.sum() / m.numel() )

tensor(0.9940, device='cuda:0')

In [15]:
filter_abalation_mask = torch.ones(size=m.shape, dtype=torch.bool)

In [16]:
m[:1].shape

torch.Size([1, 128, 3, 3])

In [17]:
pruner.S[20]

0.9939926401056366

In [18]:
1-pruner.S[20]

0.0060073598943634066

In [19]:
def get_filter_s(filter) -> float:
    return (filter.sum() / filter.numel()).item()

filter_sparsities = list(map(get_filter_s, m))
avg_filter_s = sum(filter_sparsities)/len(filter_sparsities)
if avg_filter_s < 0.1:
    print(avg_filter_s)
m[0].numel()
num_filters = m.numel()
# m.shape[0] * m[0].numel()



0.006008572149767133


In [20]:
def get_filters_to_prune(mask):
    m = mask
    kernel_size = m.shape[-2] * m.shape[-1]
    in_channels = m.shape[0]
    out_channels = m.shape[1]
    avg_filter_s = []
    for filter in m:
        avg_filter_s.append((filter.sum() / filter.numel()).item())
    print(sum(avg_filter_s) / len(avg_filter_s))
    print(torch.std(torch.tensor(avg_filter_s)).item())
        
for idx, mask in enumerate(pruner.backward_masks):
    # if idx != 20:
    #     continue
    print(f"Layer {idx}: ")
    if mask is None:
        print( "NONE")
    else:
        get_filters_to_prune(mask)
        

# filter_abalation_mask = torch.ones(shape=mask.shape, dtype=torch.bool)
        

Layer 0: 
0.196759263984859
0.08629842102527618
Layer 1: 
0.039713542646495625
0.013802867382764816
Layer 2: 
0.025716146221384406
0.008679855614900589
Layer 3: 
0.33203125
0.1387929469347
Layer 4: 
0.025716146228660364
0.008266767486929893
Layer 5: 
0.025716146221384406
0.009327770210802555
Layer 6: 
0.025716146337799728
0.00817213486880064
Layer 7: 
0.02571614623593632
0.011185742914676666
Layer 8: 
0.018717448412644444
0.00805725622922182
Layer 9: 
0.012315538364418899
0.004985100124031305
Layer 10: 
0.162109375
0.06678630411624908
Layer 11: 
0.012315538457187358
0.004511543083935976
Layer 12: 
0.01231553842080757
0.0046781389974057674
Layer 13: 
0.012315538442635443
0.005127036478370428
Layer 14: 
0.012315538398979697
0.0044474611058831215
Layer 15: 
0.009087456804991234
0.003371547209098935
Layer 16: 
0.00600857215204087
0.002363148145377636
Layer 17: 
0.080078125
0.033391211181879044
Layer 18: 
0.0060085721570430906
0.002108385320752859
Layer 19: 
0.0060085721570430906
0.00247115

5/9

In [21]:
l = torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3))

In [22]:
l.weight.shape

torch.Size([32, 3, 3, 3])

In [23]:
m.shape

torch.Size([128, 128, 3, 3])

In [24]:
Checkpoint.parent_dir

PosixPath('/home/user/condensed-sparsity')

In [25]:
cfg.paths.checkpoints

'/home/user/condensed-sparsity/artifacts/checkpoints'

In [26]:
# 99% sparse runs
from rigl_torch.utils.checkpoint import Checkpoint
const_fan_in_run_id = "2d4v4ezc"
vanilla_rigl_run_id = "xhnqnd6c"

const_fan_ckp = Checkpoint.load_last_checkpoint(run_id=const_fan_in_run_id, parent_dir=cfg.paths.checkpoints)

INFO:rigl_torch.utils.checkpoint:Loading checkpoint from /home/user/condensed-sparsity/artifacts/checkpoints/20221009_2d4v4ezc/checkpoint.pt.tar...


In [27]:
vanilla_rigl_ckp = Checkpoint.load_last_checkpoint(run_id=vanilla_rigl_run_id, parent_dir=cfg.paths.checkpoints)

INFO:rigl_torch.utils.checkpoint:Loading checkpoint from /home/user/condensed-sparsity/artifacts/checkpoints/20221011_xhnqnd6c/checkpoint.pt.tar...


In [28]:
const_fan_masks =const_fan_ckp.pruner["backward_masks"]
vanilla_rigl_masks =vanilla_rigl_ckp.pruner["backward_masks"]

In [29]:
1 - ( vanilla_rigl_masks[35].sum() / vanilla_rigl_masks[35].numel() ) 

tensor(0.9957, device='cuda:0')

In [30]:
total = 0
zeros = 0
for filter in vanilla_rigl_masks[35]:
    total+=1
    if filter.any():
        # print("Not zero!")
        continue
    else:
        # print("zero")
        zeros+=1
print(zeros/total)
print(zeros)

0.2421875
62


In [31]:
non_zero = []
max = 0
for filter in vanilla_rigl_masks[35]:
    non_zero.append(filter.sum())
    if max < filter.sum().item():
        max = filter.sum().item()
non_zero = torch.stack(non_zero).type(torch.float32)
print(torch.mean(non_zero))
print(torch.std(non_zero))
print(max)

tensor(9.8672, device='cuda:0')
tensor(10.5944, device='cuda:0')
46


In [32]:
vanilla_rigl_masks[35].sum() / vanilla_rigl_masks[35].numel() 

tensor(0.0043, device='cuda:0')

In [33]:
1/256*100

0.390625

In [34]:
9/(256*3*3)

0.00390625

In [35]:
1 - ( vanilla_rigl_masks[35].sum() / vanilla_rigl_masks[35].numel() ) 

tensor(0.9957, device='cuda:0')

In [36]:
46/(256*3*3)*100

1.9965277777777777

In [37]:
10/(256*3*3)*100

0.4340277777777778

In [38]:
thres = 0.5
n = 0.5/100/fan_in*()
fan_in=10
fan_in/((256-n)*3*3)*100

NameError: name 'fan_in' is not defined

In [None]:
256-10*100/(0.5*9)

: 

In [None]:
def get_filters_to_prune(mask):
    m = mask
    kernel_size = m.shape[-2] * m.shape[-1]
    in_channels = m.shape[0]
    out_channels = m.shape[1]
    mask_sparsity = 1 - (mask.sum() / mask.numel()).item()
    fan_in = mask[0].sum()
    target_filter_sparsity_percent = 0.5
    return out_channels - fan_in * 100 / (target_filter_sparsity_percent * kernel_size)
    avg_filter_s = []
    for filter in m:
        avg_filter_s.append((filter.sum() / filter.numel()).item())
    print(sum(avg_filter_s) / len(avg_filter_s))
    print(torch.std(torch.tensor(avg_filter_s)).item())

: 

In [None]:
fan_in=10
n=33.78
fan_in/((256-n)*3*3)*100

: 

In [None]:
vanilla_rigl_masks[35].sum() / vanilla_rigl_masks[35].numel()*100

: 

In [None]:
vanilla_rigl_masks[35].numel()

: 

In [None]:
vanilla_rigl_masks[35].sum() / ((256-62)*256*3*3) * 100

: 

In [None]:
vanilla_rigl_masks[35].sum() / ((256-0)*256*3*3) * 100

: 

In [None]:
62*256*3*3 / (256*256*3*3)

: 

In [None]:
non_zero.type(torch.float32)

: 

In [None]:
(non_zero>9).sum()

: 

In [None]:
194/256

: 

In [None]:
62*256*3*3 / (256*256*3*3)

: 

In [None]:
vanilla_rigl_masks[35].sum() / ((256-62)*256*3*3)

: 

In [None]:
vanilla_rigl_masks[35].sum() / (256*256*3*3)

: 

In [None]:
vanilla_rigl_ckp.pruner["S"][35]

: 

In [None]:
1-vanilla_rigl_ckp.pruner["S"][35]

: 

In [None]:
# Compare this with model weights to close loop on investigating 00 weights

: 

In [None]:
vanilla_rigl_ckp.pruner.keys()

: 

In [None]:
for idx, f in enumerate(vanilla_rigl_masks[35]):
    print(1 - (f.sum() / f.numel()))

: 

: 

In [None]:
10/(256*9)

: 

In [None]:
const_fan_masks[35][0].sum()

: 

In [None]:
for idx, mask in enumerate(const_fan_masks):
    if idx != 35:
        continue
    print(f"Layer {idx}: ")
    if mask is None:
        print( "NONE")
    else:
        get_filters_to_prune(mask)
        
for idx, mask in enumerate(vanilla_rigl_masks):
    if idx != 35:
        continue
    print(f"Layer {idx}: ")
    if mask is None:
        print( "NONE")
    else:
        get_filters_to_prune(mask)

: 