In [64]:
%matplotlib inline
import numpy as np
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from rigl_torch.models import ModelFactory
from rigl_torch.optim.cosine_annealing_with_linear_warm_up import CosineAnnealingWithLinearWarmUp
from rigl_torch.rigl_constant_fan import RigLConstFanScheduler
from rigl_torch.rigl_scheduler import RigLScheduler

INFO:/home/condensed-sparsity/src/rigl_torch/models/model_factory.py:Registering mnist for mnist dataset to ModelFactory...
INFO:/home/condensed-sparsity/src/rigl_torch/models/model_factory.py:Registering resnet18 for cifar10 dataset to ModelFactory...
INFO:/home/condensed-sparsity/src/rigl_torch/models/model_factory.py:Registering wide_resnet22 for cifar10 dataset to ModelFactory...
  warn(f"Failed to load image Python extension: {e}")
INFO:/home/condensed-sparsity/src/rigl_torch/models/model_factory.py:Registering cond_net for mnist dataset to ModelFactory...


In [65]:
from rigl_torch.datasets import get_dataloaders
from omegaconf import DictConfig
import hydra

with hydra.initialize(config_path="../configs"):
    cfg = hydra.compose(config_name="config.yaml", overrides=[])
cfg

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize(config_path="../configs"):


{'dataset': {'name': 'mnist', 'num_classes': 10, 'train_len': 60000}, 'model': {'name': 'mnist'}, 'experiment': {'comment': 'missing_sweep_${rigl.dense_allocation}_${rigl.const_fan_in}', 'name': '${model.name}_${dataset.name}_${experiment.comment}', 'sweep': True, 'resume_from_checkpoint': True, 'run_id': '1hz0vtiw'}, 'paths': {'data_folder': '/home/condensed-sparsity/data', 'artifacts': '/home/condensed-sparsity/artifacts', 'logs': '/home/condensed-sparsity/logs', 'checkpoints': '${paths.artifacts}/checkpoints'}, 'rigl': {'dense_allocation': 0.5, 'delta': 100, 'grad_accumulation_n': 1, 'alpha': 0.3, 'static_topo': 0, 'const_fan_in': False, 'sparsity_distribution': 'erk', 'erk_power_scale': 1.0}, 'training': {'dry_run': False, 'batch_size': 128, 'test_batch_size': 1000, 'epochs': 250, 'seed': 42, 'log_interval': 10000, 'save_model': True, 'optimizer': 'sgd', 'weight_decay': 0.0005, 'momentum': 0.9, 'scheduler': 'step_lr', 'lr': 0.1, 'init_lr': 0, 'warm_up_steps': 5, 'gamma': 0.2, 'step

In [66]:
cfg.rigl.sparsity_distribution = "erk"

In [67]:
# model = ModelFactory.load_model(model="mnist", dataset='mnist').to(device)
use_cuda = not cfg.compute.no_cuda and torch.cuda.is_available()
torch.manual_seed(cfg.training.seed)
device = torch.device("cuda" if use_cuda else "cpu")
train_loader, test_loader = get_dataloaders(cfg)

model = ModelFactory.load_model(model=cfg.model.name, dataset=cfg.dataset.name).to(device)
# model = get_model(cfg).to(device)
optimizer = torch.optim.Adadelta(model.parameters(), lr=cfg.training.lr)
scheduler = CosineAnnealingWithLinearWarmUp(
    optimizer,
    T_max=cfg.training.epochs,
    eta_min=0,
    lr=cfg.training.lr,
    warm_up_steps=cfg.training.warm_up_steps,
)

pruner = lambda: True  # noqa: E731
if cfg.rigl.dense_allocation is not None:
    T_end = int(0.75 * cfg.training.epochs * len(train_loader))
    if cfg.rigl.const_fan_in:
        rigl_scheduler = RigLConstFanScheduler
    else:
        rigl_scheduler = RigLScheduler
    pruner = rigl_scheduler(
        model,
        optimizer,
        dense_allocation=cfg.rigl.dense_allocation,
        alpha=cfg.rigl.alpha,
        delta=cfg.rigl.delta,
        static_topo=cfg.rigl.static_topo,
        T_end=T_end,
        ignore_linear_layers=False,
        grad_accumulation_n=cfg.rigl.grad_accumulation_n,
        sparsity_distribution=cfg.rigl.sparsity_distribution,
        erk_power_scale=cfg.rigl.erk_power_scale,
    )
else:
    print(
        "cfg.rigl.dense_allocation is `null`, training with dense "
        "network..."
    )

INFO:/home/condensed-sparsity/src/rigl_torch/models/model_factory.py:Loading model mnist/mnist using <function Mnist at 0x7f03dc419940> with args: () and kwargs: {}
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 0 set to 0.0
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 3 set to 0.0


INFO:/home/condensed-sparsity/src/rigl_torch/models/model_factory.py:Loading model resnet18/cifar10 using <function ResNet18 at 0x7f00e14fb160> with args: () and kwargs: {}
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 20 set to 0.0
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 0 set to 0.0
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 7 set to 0.0
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 12 set to 0.0
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 17 set to 0.0
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 1 set to 0.0
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 2 set to 0.0
INFO:/home/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 3 set to 0.0


In [68]:
from rigl_torch.utils.checkpoint import Checkpoint
checkpoint = Checkpoint(
    run_id = "test2",
    cfg = cfg,
    model = model,
    optimizer=optimizer,
    scheduler=scheduler,
    pruner=pruner,
)

# checkpoint = Checkpoint.load_last_checkpoint(run_id = "1hz0vtiw")

In [70]:
def train(
    cfg, model, device, train_loader, optimizer, epoch, pruner, scheduler, step
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        step += 1
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        logits = model(data)
        output = F.log_softmax(logits, dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()

        if pruner():
            optimizer.step()
        scheduler.step()

        if step % cfg.training.log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
        if cfg.training.dry_run:
            print("Dry run, exiting after one training step")
            return step
    return step

In [72]:
checkpoint._update_best_flag()

INFO:rigl_torch.utils.checkpoint:New best checkpoint accuracy (1.000000 > -inf)!


In [9]:
checkpoint.is_best

True

In [10]:
checkpoint.save_checkpoint()

INFO:rigl_torch.utils.checkpoint:Checkpoint state saved!


In [11]:
Checkpoint.parent_dir.__str__()

'/home/condensed-sparsity/artifacts/checkpoints'

In [12]:
checkpoint = Checkpoint.load_last_checkpoint(run_id="test")

INFO:rigl_torch.utils.checkpoint:Loading checkpoint from /home/condensed-sparsity/artifacts/checkpoints/20220819_test/checkpoint.pt.tar...


In [14]:
checkpoint._update_state()

{'run_id': 'test',
 'cfg': {'dataset': {'name': 'cifar10', 'normalize': False, 'num_classes': 10, 'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], 'train_len': 50000}, 'model': {'name': 'resnet18'}, 'experiment': {'comment': 'missing_sweep_${rigl.dense_allocation}_${rigl.const_fan_in}', 'name': '${model.name}_${dataset.name}_${experiment.comment}', 'sweep': True}, 'paths': {'data_folder': '/home/condensed-sparsity/data', 'artifacts': '/home/condensed-sparsity/artifacts', 'logs': '/home/condensed-sparsity/logs', 'checkpoints': '${paths.artifacts}/checkpoints'}, 'rigl': {'dense_allocation': 0.5, 'delta': 100, 'grad_accumulation_n': 1, 'alpha': 0.3, 'static_topo': 0, 'const_fan_in': False, 'sparsity_distribution': 'erk', 'erk_power_scale': 1.0}, 'training': {'dry_run': False, 'batch_size': 128, 'test_batch_size': 1000, 'epochs': 250, 'seed': 42, 'log_interval': 10000, 'save_model': True, 'optimizer': 'sgd', 'weight_decay': 0.0005, 'mom

In [5]:

def train(
    cfg, model, device, train_loader, optimizer, epoch, pruner, scheduler, step
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        step += 1
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        logits = model(data)
        output = F.log_softmax(logits, dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()

        if pruner():
            optimizer.step()
        scheduler.step()

        if step % cfg.training.log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
        if cfg.training.dry_run:
            print("Dry run, exiting after one training step")
            return step
    return step

In [28]:
cfg.training.epochs = 2 
step = 0
for epoch in range(1, cfg.training.epochs + 1):
        print(pruner)
        step = train(
            cfg,
            model,
            device,
            train_loader,
            optimizer,
            epoch,
            pruner=pruner,
            scheduler=scheduler,
            step=step,
        )
# checkpoint.current_acc = acc
checkpoint.step = step
checkpoint.epoch = epoch
checkpoint.save_checkpoint()

RigLScheduler(
layers=4,
nonzero_params=[288/288, 6461/18432, 591796/1179648, 1280/1280],
nonzero_percentages=[100.00%, 35.05%, 50.17%, 100.00%],
total_nonzero_params=599825/1199648 (50.00%),
total_CONV_nonzero_params=6749/18720 (36.05%),
step=0,
num_rigl_steps=0,
ignoring_linear_layers=False,
sparsity_distribution=erk,
)




RigLScheduler(
layers=4,
nonzero_params=[288/288, 6461/18432, 591796/1179648, 1280/1280],
nonzero_percentages=[100.00%, 35.05%, 50.17%, 100.00%],
total_nonzero_params=599825/1199648 (50.00%),
total_CONV_nonzero_params=6749/18720 (36.05%),
step=469,
num_rigl_steps=4,
ignoring_linear_layers=False,
sparsity_distribution=erk,
)


INFO:rigl_torch.utils.checkpoint:New best checkpoint accuracy (0.000000 > -inf)!
INFO:rigl_torch.utils.checkpoint:Checkpoint state saved!
INFO:rigl_torch.utils.checkpoint:Best checkpoint state saved!


In [73]:
model_org = model

In [74]:
hex(id(model_org))

'0x7f0460366a00'

In [75]:
hex(id(model))

'0x7f0460366a00'

In [76]:
hex(id(checkpoint.model))

'0x7f0460366a00'

In [77]:
checkpoint = Checkpoint.load_last_checkpoint(run_id = "test2")

INFO:rigl_torch.utils.checkpoint:Loading checkpoint from /home/condensed-sparsity/artifacts/checkpoints/20220822_test2/checkpoint.pt.tar...


In [79]:
model = ModelFactory.load_model(model="mnist", dataset="mnist", state_dict = checkpoint.model)

INFO:/home/condensed-sparsity/src/rigl_torch/models/model_factory.py:Loading model mnist/mnist using <function Mnist at 0x7f03dc419940> with args: () and kwargs: {}


In [81]:
hex(id(model))

'0x7f03d1293130'

In [85]:
model == model_org

False

In [97]:
model_org.to(device)
model.to(device)

MnistNet(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [106]:
for (k_org, v_org), (k, v) in list(zip(model_org.state_dict().items(), model.state_dict().items())):
    print(k)
    print((v == v_org).all())

conv1.weight
tensor(True, device='cuda:0')
conv1.bias
tensor(True, device='cuda:0')
conv2.weight
tensor(True, device='cuda:0')
conv2.bias
tensor(True, device='cuda:0')
fc1.weight
tensor(True, device='cuda:0')
fc1.bias
tensor(True, device='cuda:0')
fc2.weight
tensor(True, device='cuda:0')
fc2.bias
tensor(True, device='cuda:0')


In [107]:
pruner_org = pruner

In [108]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import random
from torch.utils.tensorboard import SummaryWriter
import omegaconf
import hydra
import logging
import wandb
import pathlib
from rigl_torch.models.model_factory import ModelFactory

from rigl_torch.rigl_scheduler import RigLScheduler
from rigl_torch.rigl_constant_fan import RigLConstFanScheduler
from rigl_torch.datasets import get_dataloaders
from rigl_torch.optim import (
    get_optimizer,
    get_lr_scheduler,
)
from rigl_torch.utils.checkpoint import Checkpoint

checkpoint = Checkpoint.load_last_checkpoint(
            run_id="test2"
)
_RESUME_FROM_CHECKPOINT = True
wandb_init_resume = "must"
run_id = checkpoint.run_id
optimizer_state = checkpoint.optimizer
scheduler_state = checkpoint.scheduler
pruner_state = checkpoint.pruner
model_state = checkpoint.model
cfg = checkpoint.cfg    
use_cuda = not cfg.compute.no_cuda and torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")
train_loader, test_loader = get_dataloaders(cfg)

model = ModelFactory.load_model(
    model=cfg.model.name, dataset=cfg.dataset.name, state_dict=model_state
)
model.to(device)

optimizer = get_optimizer(cfg, model, state_dict=optimizer_state)
scheduler = get_lr_scheduler(cfg, optimizer, state_dict=scheduler_state)

pruner = None  # noqa: E731
if cfg.rigl.dense_allocation is not None:
    T_end = int(0.75 * cfg.training.epochs * len(train_loader))
    if cfg.rigl.const_fan_in:
        rigl_scheduler = RigLConstFanScheduler
    else:
        rigl_scheduler = RigLScheduler
    pruner = rigl_scheduler(
        model,
        optimizer,
        dense_allocation=cfg.rigl.dense_allocation,
        alpha=cfg.rigl.alpha,
        delta=cfg.rigl.delta,
        static_topo=cfg.rigl.static_topo,
        T_end=T_end,
        ignore_linear_layers=False,
        grad_accumulation_n=cfg.rigl.grad_accumulation_n,
        sparsity_distribution=cfg.rigl.sparsity_distribution,
        erk_power_scale=cfg.rigl.erk_power_scale,
        state_dict=pruner_state,
    )

INFO:rigl_torch.utils.checkpoint:Loading checkpoint from /home/condensed-sparsity/artifacts/checkpoints/20220822_test2/checkpoint.pt.tar...
INFO:/home/condensed-sparsity/src/rigl_torch/models/model_factory.py:Loading model mnist/mnist using <function Mnist at 0x7f03dc419940> with args: () and kwargs: {}


In [115]:
for (k_org, v_org), (k, v) in list(zip(pruner_org.state_dict().items(), pruner.state_dict().items())):
    print(k)
    if k == "backward_masks":
        continue
    if type(v) == torch.Tensor or k =="backward_masks":
        print((v == v_org).all())
    else:
        print(v==v_org)

dense_allocation
True
S
True
N
True
hyperparams
True
step
True
rigl_steps
True
backward_masks
_linear_layers_mask
True


In [125]:
p.numel()

288

In [128]:
non_zero = 0
numel = 0
for p in model.parameters():
    print(torch.count_nonzero(p) / p.numel())
    non_zero +=torch.count_nonzero(p)
    numel += p.numel()

non_zero/numel

tensor(1., device='cuda:0')
tensor(1., device='cuda:0')
tensor(0.3505, device='cuda:0')
tensor(1., device='cuda:0')
tensor(0.5017, device='cuda:0')
tensor(1., device='cuda:0')
tensor(1., device='cuda:0')
tensor(1., device='cuda:0')


tensor(0.5001, device='cuda:0')

In [118]:
pruner_org.__str__()

'RigLScheduler(\nlayers=4,\nnonzero_params=[288/288, 6461/18432, 591796/1179648, 1280/1280],\nnonzero_percentages=[100.00%, 35.05%, 50.17%, 100.00%],\ntotal_nonzero_params=599825/1199648 (50.00%),\ntotal_CONV_nonzero_params=6749/18720 (36.05%),\nstep=938,\nnum_rigl_steps=9,\nignoring_linear_layers=False,\nsparsity_distribution=erk,\n)'

In [110]:
pruner_org.state_dict()

{'dense_allocation': 0.5,
 'S': [0.0, 0.6495171015950314, 0.49832840032229986, 0.0],
 'N': [288, 18432, 1179648, 1280],
 'hyperparams': {'delta_T': 100,
  'alpha': 0.3,
  'T_end': 87937,
  'ignore_linear_layers': False,
  'static_topo': 0,
  'sparsity_distribution': 'erk',
  'grad_accumulation_n': 1,
  'erk_power_scale': 1.0},
 'step': 938,
 'rigl_steps': 9,
 'backward_masks': [None,
  tensor([[[[False,  True,  True],
            [ True,  True,  True],
            [ True,  True,  True]],
  
           [[ True,  True,  True],
            [ True,  True,  True],
            [ True,  True,  True]],
  
           [[False, False, False],
            [False, False, False],
            [False, False, False]],
  
           ...,
  
           [[False,  True,  True],
            [False,  True, False],
            [ True,  True,  True]],
  
           [[ True,  True,  True],
            [ True,  True,  True],
            [ True,  True,  True]],
  
           [[ True,  True,  True],
            [F

In [109]:
pruner.state_dict()

{'dense_allocation': 0.5,
 'S': [0.0, 0.6495171015950314, 0.49832840032229986, 0.0],
 'N': [288, 18432, 1179648, 1280],
 'hyperparams': {'delta_T': 100,
  'alpha': 0.3,
  'T_end': 87937,
  'ignore_linear_layers': False,
  'static_topo': 0,
  'sparsity_distribution': 'erk',
  'grad_accumulation_n': 1,
  'erk_power_scale': 1.0},
 'step': 938,
 'rigl_steps': 9,
 'backward_masks': [None,
  tensor([[[[False,  True,  True],
            [ True,  True,  True],
            [ True,  True,  True]],
  
           [[ True,  True,  True],
            [ True,  True,  True],
            [ True,  True,  True]],
  
           [[False, False, False],
            [False, False, False],
            [False, False, False]],
  
           ...,
  
           [[False,  True,  True],
            [False,  True, False],
            [ True,  True,  True]],
  
           [[ True,  True,  True],
            [ True,  True,  True],
            [ True,  True,  True]],
  
           [[ True,  True,  True],
            [F

In [98]:
model_org.state_dict() == model.state_dict()

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [92]:
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 0.3528,  0.4003,  0.0133],
                        [ 0.4471,  0.0982,  0.1988],
                        [-0.0790,  0.2990,  0.3824]]],
              
              
                      [[[-0.1792,  0.3867,  0.1595],
                        [ 0.3507,  0.1934,  0.2975],
                        [ 0.0377,  0.3706,  0.1597]]],
              
              
                      [[[-0.1724,  0.0647, -0.1727],
                        [-0.0576, -0.1534,  0.2079],
                        [-0.2811, -0.1693, -0.1038]]],
              
              
                      [[[-0.1935,  0.0240, -0.3362],
                        [ 0.3352, -0.2493,  0.2932],
                        [ 0.1011, -0.0552,  0.2610]]],
              
              
                      [[[ 0.1100,  0.3511,  0.0979],
                        [-0.0168,  0.2196,  0.0105],
                        [ 0.2290,  0.4065,  0.2876]]],
              
              
               

In [None]:
import pathlib

checkpoint_dir = pathlib.Path("./")

In [None]:
test="str"
type(test)

str

In [None]:
type(checkpoint_dir)

pathlib.PosixPath

In [None]:
import datetime

datetime.date.today().isoformat()
# datetime.datetime.now()

'2022-08-19'

In [None]:
pruner.__dir__()

['_implemented_sparsity_distributions',
 '_logger',
 'erk_power_scale',
 'model',
 'optimizer',
 'W',
 '_linear_layers_mask',
 'dense_allocation',
 'N',
 'sparsity_distribution',
 'static_topo',
 'grad_accumulation_n',
 'ignore_linear_layers',
 'backward_masks',
 'S',
 'step',
 'rigl_steps',
 'delta_T',
 'alpha',
 'T_end',
 'backward_hook_objects',
 '__module__',
 '__init__',
 '_allocate_sparsity',
 '_uniform_sparsity_dist',
 '_erk_sparsity_dist',
 'state_dict',
 'load_state_dict',
 'random_sparsify',
 '__str__',
 'reset_momentum',
 'apply_mask_to_weights',
 'apply_mask_to_gradients',
 'check_if_backward_hook_should_accumulate_grad',
 'cosine_annealing',
 '__call__',
 '_rigl_step',
 '__dict__',
 '__weakref__',
 '__doc__',
 '__repr__',
 '__hash__',
 '__getattribute__',
 '__setattr__',
 '__delattr__',
 '__lt__',
 '__le__',
 '__eq__',
 '__ne__',
 '__gt__',
 '__ge__',
 '__new__',
 '__reduce_ex__',
 '__reduce__',
 '__subclasshook__',
 '__init_subclass__',
 '__format__',
 '__sizeof__',
 '__d

In [None]:
getattr(pruner, "state_dict2")

AttributeError: 'RigLScheduler' object has no attribute 'state_dict2'

In [None]:
pruner.state_dict()["step"]

782

In [None]:
pruner.state_dict()["rigl_steps"]

7

In [None]:
pruner.state_dict().keys()

dict_keys(['dense_allocation', 'S', 'N', 'hyperparams', 'step', 'rigl_steps', 'backward_masks', '_linear_layers_mask'])

In [None]:
print(pruner.state_dict()["hyperparams"])

{'delta_T': 100, 'alpha': 0.3, 'T_end': 73312, 'ignore_linear_layers': False, 'static_topo': 0, 'sparsity_distribution': 'erk', 'grad_accumulation_n': 1}


In [None]:
print(pruner)

RigLScheduler(
layers=21,
nonzero_params=[1728/1728, 36864/36864, 36864/36864, 36864/36864, 36864/36864, 73728/73728, 147456/147456, 8192/8192, 147456/147456, 147456/147456, 294912/294912, 424959/589824, 32768/32768, 424959/589824, 424959/589824, 634976/1179648, 844994/2359296, 131072/131072, 844994/2359296, 844994/2359296, 5120/5120],
nonzero_percentages=[100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 72.05%, 100.00%, 72.05%, 72.05%, 53.83%, 35.82%, 100.00%, 35.82%, 35.82%, 100.00%],
total_nonzero_params=5582179/11164352 (50.00%),
total_CONV_nonzero_params=5577059/11159232 (49.98%),
step=0,
num_rigl_steps=0,
ignoring_linear_layers=False,
sparsity_distribution=erk,
)


In [None]:
from rigl_torch.util import get_W
W = get_W(model, return_linear_layers_mask=False)
for w in W:
    print(w.shape)
    print(w.numel())

torch.Size([64, 3, 3, 3])
1728
torch.Size([64, 64, 3, 3])
36864
torch.Size([64, 64, 3, 3])
36864
torch.Size([64, 64, 3, 3])
36864
torch.Size([64, 64, 3, 3])
36864
torch.Size([128, 64, 3, 3])
73728
torch.Size([128, 128, 3, 3])
147456
torch.Size([128, 64, 1, 1])
8192
torch.Size([128, 128, 3, 3])
147456
torch.Size([128, 128, 3, 3])
147456
torch.Size([256, 128, 3, 3])
294912
torch.Size([256, 256, 3, 3])
589824
torch.Size([256, 128, 1, 1])
32768
torch.Size([256, 256, 3, 3])
589824
torch.Size([256, 256, 3, 3])
589824
torch.Size([512, 256, 3, 3])
1179648
torch.Size([512, 512, 3, 3])
2359296
torch.Size([512, 256, 1, 1])
131072
torch.Size([512, 512, 3, 3])
2359296
torch.Size([512, 512, 3, 3])
2359296
torch.Size([10, 512])
5120


In [None]:
total_el = 0
non_zero_el = 0
for mask, weights in list(zip(pruner.backward_masks, pruner.W)):
    if mask is None:
        total_el += weights.numel()
        non_zero_el += weights.numel()
    else:
        total_el +=weights.numel()
        non_zero_el+=mask.sum()

In [None]:
total_el

11164352

In [None]:
non_zero_el

tensor(5582179, device='cuda:0')

In [None]:
non_zero_el / total_el

tensor(0.5000, device='cuda:0')

In [None]:
pruner.state_dict().keys()

dict_keys(['dense_allocation', 'S', 'N', 'hyperparams', 'step', 'rigl_steps', 'backward_masks', '_linear_layers_mask'])

In [None]:
state_dict = pruner.state_dict()
state_dict["_linear_layers_mask"]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

In [None]:
pruner.backward_masks

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 tensor([[[[ True, False,  True],
           [ True,  True, False],
           [False, False,  True]],
 
          [[ True, False, False],
           [ True,  True,  True],
           [False,  True,  True]],
 
          [[False, False,  True],
           [False,  True, False],
           [ True, False,  True]],
 
          ...,
 
          [[ True,  True,  True],
           [ True,  True,  True],
           [ True, False,  True]],
 
          [[False,  True,  True],
           [False,  True,  True],
           [False,  True, False]],
 
          [[ True,  True,  True],
           [ True,  True, False],
           [False,  True,  True]]],
 
 
         [[[ True, False, False],
           [ True,  True,  True],
           [ True,  True,  True]],
 
          [[False,  True, False],
           [ True,  True,  True],
           [False, False,  True]],
 
          [[ True,  True,  True],
           [ True, False,  Tr