In [1]:
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import pytorch_lightning as pl
import random
import dotenv
import omegaconf
import hydra
import logging
import wandb
from datetime import date
import pathlib
from typing import Dict, Any
from copy import deepcopy

from rigl_torch.models.model_factory import ModelFactory
from rigl_torch.rigl_scheduler import RigLScheduler
from rigl_torch.rigl_constant_fan import RigLConstFanScheduler
from rigl_torch.datasets import get_dataloaders
from rigl_torch.optim import (
    get_optimizer,
    get_lr_scheduler,
)
from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.utils.rigl_utils import get_T_end, get_fan_in_after_ablation, get_conv_idx_from_flat_idx

dataset = "cifar10"
model="wide_resnet22"
with hydra.initialize(config_path="../configs"):
    cfg = hydra.compose(
        config_name="config.yaml", 
        overrides=[
            f"dataset={dataset}",
            f"model={model}",
            "compute.distributed=False",
            "rigl.dense_allocation=0.01",
            "rigl.const_fan_in=True",
            "rigl.filter_ablation_threshold=0.01",
            "rigl.dynamic_ablation=True", 
            "rigl.static_ablation=False",
            "rigl.min_salient_weights_per_neuron=1",
            "rigl.delta=2",
            "rigl.grad_accumulation_n=1",
        ]
    )
cfg

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize(config_path="../configs"):


{'dataset': {'name': 'cifar10', 'normalize': False, 'num_classes': 10, 'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], 'train_len': 50000}, 'model': {'name': 'wide_resnet22', 'activation_function': 'relu'}, 'experiment': {'comment': 'dense_alloc-${rigl.dense_allocation}_const_fan_in-${rigl.const_fan_in}_weight_per_neuron-${rigl.min_salient_weights_per_neuron}_sparse_init-True-torch-init', 'name': '${model.name}_${dataset.name}_${experiment.comment}', 'resume_from_checkpoint': False, 'run_id': None}, 'paths': {'base': '${oc.env:BASE_PATH}', 'data_folder': '${paths.base}/data', 'artifacts': '${paths.base}/artifacts', 'logs': '${paths.base}/logs', 'checkpoints': '${paths.artifacts}/checkpoints'}, 'rigl': {'dense_allocation': 0.01, 'delta': 2, 'grad_accumulation_n': 1, 'alpha': 0.3, 'static_topo': 0, 'const_fan_in': True, 'sparsity_distribution': 'erk', 'erk_power_scale': 1.0, 'use_t_end': True, 'static_ablation': False, 'dynamic_ablat

In [2]:
device = torch.device("cuda:0")
train_loader, test_loader = get_dataloaders(cfg)
model = ModelFactory.load_model(
        model=cfg.model.name, dataset=cfg.dataset.name
    )
model.to(device)
optimizer = get_optimizer(cfg, model, state_dict=None)
scheduler = get_lr_scheduler(cfg, optimizer, state_dict=None)
T_end = get_T_end(cfg, train_loader)
if cfg.rigl.const_fan_in:
    rigl_scheduler = RigLConstFanScheduler
else:
    rigl_scheduler = RigLScheduler
pruner = rigl_scheduler(
    model,
    optimizer,
    dense_allocation=cfg.rigl.dense_allocation,
    alpha=cfg.rigl.alpha,
    delta=cfg.rigl.delta,
    static_topo=cfg.rigl.static_topo,
    T_end=T_end,
    ignore_linear_layers=False,
    grad_accumulation_n=cfg.rigl.grad_accumulation_n,
    sparsity_distribution=cfg.rigl.sparsity_distribution,
    erk_power_scale=cfg.rigl.erk_power_scale,
    state_dict=None,
    filter_ablation_threshold=cfg.rigl.filter_ablation_threshold,
    static_ablation=cfg.rigl.static_ablation,
    dynamic_ablation=cfg.rigl.dynamic_ablation,
    min_salient_weights_per_neuron=cfg.rigl.min_salient_weights_per_neuron,
    )

Files already downloaded and verified


INFO:/home/user/condensed-sparsity/src/rigl_torch/models/model_factory.py:Loading model wide_resnet22/cifar10 using <function get_wide_resnet_22 at 0x7f5cb2ba51f0> with args: () and kwargs: {}


In [8]:
from rigl_torch.utils.rigl_utils import get_W, get_names_and_W

W = get_W(model, False)
len(W)

23

In [10]:
names, W_prime = get_names_and_W(model)

In [11]:
assert W == W_prime

In [13]:
names

['conv1',
 'block1.layer.0.conv1',
 'block1.layer.0.conv2',
 'block1.layer.0.convShortcut',
 'block1.layer.1.conv1',
 'block1.layer.1.conv2',
 'block1.layer.2.conv1',
 'block1.layer.2.conv2',
 'block2.layer.0.conv1',
 'block2.layer.0.conv2',
 'block2.layer.0.convShortcut',
 'block2.layer.1.conv1',
 'block2.layer.1.conv2',
 'block2.layer.2.conv1',
 'block2.layer.2.conv2',
 'block3.layer.0.conv1',
 'block3.layer.0.conv2',
 'block3.layer.0.convShortcut',
 'block3.layer.1.conv1',
 'block3.layer.1.conv2',
 'block3.layer.2.conv1',
 'block3.layer.2.conv2',
 'fc']