In [1]:
%matplotlib inline
import numpy as np
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from rigl_torch.models import ModelFactory
from rigl_torch.optim.cosine_annealing_with_linear_warm_up import CosineAnnealingWithLinearWarmUp
from rigl_torch.rigl_constant_fan import RigLConstFanScheduler
from rigl_torch.rigl_scheduler import RigLScheduler

In [2]:
from rigl_torch.datasets import get_dataloaders
from omegaconf import DictConfig
import hydra

with hydra.initialize(config_path="../configs"):
    cfg = hydra.compose(config_name="config.yaml", overrides=[])
cfg

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize(config_path="../configs"):


{'dataset': {'name': 'cifar10', 'normalize': False, 'num_classes': 10, 'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']}, 'model': {'name': 'wide_resnet22'}, 'experiment': {'comment': 'erk_testing', 'name': '${model.name}_${dataset.name}_${experiment.comment}'}, 'paths': {'data_folder': '/home/condensed-sparsity/data', 'artifacts': '/home/condensed-sparsity/artifacts', 'logs': '/home/condensed-sparsity/logs'}, 'rigl': {'dense_allocation': 0.1, 'delta': 100, 'grad_accumulation_n': 1, 'alpha': 0.3, 'static_topo': 0, 'const_fan_in': False, 'sparsity_distribution': 'erk', 'erk_power_scale': 1.0}, 'training': {'batch_size': 64, 'test_batch_size': 10, 'epochs': 50, 'lr': 0.1, 'init_lr': 1e-06, 'warm_up_steps': 5, 'gamma': 0.7, 'dry_run': False, 'seed': 1, 'log_interval': 10, 'save_model': True, 'weight_decay': 0, 'momentum': 0.9, 'optimizer': 'adadelta'}, 'compute': {'no_cuda': False, 'cuda_kwargs': {'num_workers': 1, 'pin_memory': True, '

In [3]:
cfg.rigl.sparsity_distribution = "erk"

In [7]:
# model = ModelFactory.load_model(model="mnist", dataset='mnist').to(device)
use_cuda = not cfg.compute.no_cuda and torch.cuda.is_available()
torch.manual_seed(cfg.training.seed)
device = torch.device("cuda" if use_cuda else "cpu")
train_loader, test_loader = get_dataloaders(cfg)

model = ModelFactory.load_model(model=cfg.model.name, dataset=cfg.dataset.name).to(device)
# model = get_model(cfg).to(device)
optimizer = torch.optim.Adadelta(model.parameters(), lr=cfg.training.lr)
scheduler = CosineAnnealingWithLinearWarmUp(
    optimizer,
    T_max=cfg.training.epochs,
    eta_min=0,
    lr=cfg.training.lr,
    warm_up_steps=cfg.training.warm_up_steps,
)

pruner = lambda: True  # noqa: E731
if cfg.rigl.dense_allocation is not None:
    T_end = int(0.75 * cfg.training.epochs * len(train_loader))
    if cfg.rigl.const_fan_in:
        rigl_scheduler = RigLConstFanScheduler
    else:
        rigl_scheduler = RigLScheduler
    pruner = rigl_scheduler(
        model,
        optimizer,
        dense_allocation=cfg.rigl.dense_allocation,
        alpha=cfg.rigl.alpha,
        delta=cfg.rigl.delta,
        static_topo=cfg.rigl.static_topo,
        T_end=T_end,
        ignore_linear_layers=False,
        grad_accumulation_n=cfg.rigl.grad_accumulation_n,
        sparsity_distribution=cfg.rigl.sparsity_distribution,
        erk_power_scale=cfg.rigl.erk_power_scale,
    )
else:
    print(
        "cfg.rigl.dense_allocation is `null`, training with dense "
        "network..."
    )

Files already downloaded and verified


In [8]:
print(pruner)

RigLScheduler(
layers=23,
nonzero_params=[432/432, 4608/4608, 9216/9216, 512/512, 9216/9216, 9216/9216, 9216/9216, 9216/9216, 18432/18432, 36864/36864, 2048/2048, 36864/36864, 36864/36864, 36864/36864, 36864/36864, 73728/73728, 21537/147456, 8192/8192, 21537/147456, 21537/147456, 21537/147456, 21537/147456, 1280/1280],
nonzero_percentages=[100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 100.00%, 14.61%, 100.00%, 14.61%, 14.61%, 14.61%, 14.61%, 100.00%],
total_nonzero_params=447317/1076912 (41.54%),
total_CONV_nonzero_params=446037/1075632 (41.47%),
step=0,
num_rigl_steps=0,
ignoring_linear_layers=False,
sparsity_distribution=erk,
)


In [4]:
from rigl_torch.util import get_W
W = get_W(model, return_linear_layers_mask=False)
for w in W:
    print(w.shape)
    print(w.numel())

torch.Size([16, 3, 3, 3])
432
torch.Size([32, 16, 3, 3])
4608
torch.Size([32, 32, 3, 3])
9216
torch.Size([32, 16, 1, 1])
512
torch.Size([32, 32, 3, 3])
9216
torch.Size([32, 32, 3, 3])
9216
torch.Size([32, 32, 3, 3])
9216
torch.Size([32, 32, 3, 3])
9216
torch.Size([64, 32, 3, 3])
18432
torch.Size([64, 64, 3, 3])
36864
torch.Size([64, 32, 1, 1])
2048
torch.Size([64, 64, 3, 3])
36864
torch.Size([64, 64, 3, 3])
36864
torch.Size([64, 64, 3, 3])
36864
torch.Size([64, 64, 3, 3])
36864
torch.Size([128, 64, 3, 3])
73728
torch.Size([128, 128, 3, 3])
147456
torch.Size([128, 64, 1, 1])
8192
torch.Size([128, 128, 3, 3])
147456
torch.Size([128, 128, 3, 3])
147456
torch.Size([128, 128, 3, 3])
147456
torch.Size([128, 128, 3, 3])
147456
torch.Size([10, 128])
1280


In [5]:
total_el = 0
non_zero_el = 0
for mask, weights in list(zip(pruner.backward_masks, pruner.W)):
    if mask is None:
        total_el += weights.numel()
        non_zero_el += weights.numel()
    else:
        total_el +=weights.numel()
        non_zero_el+=mask.sum()

In [6]:
total_el

1199648

In [20]:
non_zero_el

tensor(121534, device='cuda:0')

In [7]:
non_zero_el / total_el

tensor(0.1014, device='cuda:0')

In [9]:
print(pruner)

RigLScheduler(
layers=4,
nonzero_params=[288/288, 6848/18432, 113280/1179648, 1280/1280],
nonzero_percentages=[100.00%, 37.15%, 9.60%, 100.00%],
total_nonzero_params=121696/1199648 (10.14%),
total_CONV_nonzero_params=7136/18720 (38.12%),
step=0,
num_rigl_steps=0,
ignoring_linear_layers=False,
sparsity_distribution=erk,
constant fan ins=[9, 107, 885, 128]
)


In [4]:
pruner.backward_masks

[None,
 tensor([[[[ True,  True,  True],
           [False, False, False],
           [ True,  True, False]],
 
          [[ True, False,  True],
           [False, False, False],
           [False, False, False]],
 
          [[False, False, False],
           [False,  True,  True],
           [False, False,  True]],
 
          ...,
 
          [[False, False, False],
           [ True, False, False],
           [False, False,  True]],
 
          [[ True,  True, False],
           [ True,  True, False],
           [ True, False, False]],
 
          [[False, False, False],
           [False,  True, False],
           [False,  True,  True]]],
 
 
         [[[ True, False, False],
           [False, False,  True],
           [False, False, False]],
 
          [[False,  True, False],
           [ True,  True, False],
           [False, False, False]],
 
          [[ True, False, False],
           [False,  True, False],
           [False, False,  True]],
 
          ...,
 
          [