In [1]:
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import pytorch_lightning as pl
import random
import dotenv
import omegaconf
import hydra
import logging
from typing import List

import wandb
from datetime import date
import dotenv
import os
import pathlib
from typing import Dict, Any
from copy import deepcopy

from rigl_torch.models import ModelFactory
from rigl_torch.rigl_scheduler import RigLScheduler
from rigl_torch.rigl_constant_fan import RigLConstFanScheduler
from rigl_torch.datasets import get_dataloaders
from rigl_torch.optim import (
    get_optimizer,
    get_lr_scheduler,
)
from rigl_torch.utils.rigl_utils import get_names_and_W
from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.utils.rigl_utils import get_T_end, get_fan_in_after_ablation, get_conv_idx_from_flat_idx
from hydra import initialize, compose

from fvcore.nn import FlopCountAnalysis
import pandas as pd


In [2]:
def get_pruner_model_loader(dense_alloc, model, dataset):
    with initialize("../configs", version_base="1.2.0"):
        cfg = compose(
            "config.yaml",
            overrides=[
                f"dataset={dataset}",
                "compute.distributed=False",
                f"model={model}",
                # f"rigl.dense_allocation={dense_alloc}",
                f"rigl.dense_allocation={dense_alloc}",
                ])
    dotenv.load_dotenv("../.env")
    os.environ["IMAGE_NET_PATH"]


    rank=0
    checkpoint=None
    if checkpoint is not None:
        run_id = checkpoint.run_id
        optimizer_state = checkpoint.optimizer
        scheduler_state = checkpoint.scheduler
        pruner_state = checkpoint.pruner
        model_state = checkpoint.model
        cfg = checkpoint.cfg
    else:
        run_id, optimizer_state, scheduler_state, pruner_state, model_state = (
            None,
            None,
            None,
            None,
            None,
        )

    print(cfg.compute)
    cfg.compute.distributed=False
        
    pl.seed_everything(cfg.training.seed)
    use_cuda = not cfg.compute.no_cuda and torch.cuda.is_available()
    if not use_cuda:
        raise SystemError("GPU has stopped responding...waiting to die!")
        logger.warning(
            "Using CPU! Verify cfg.compute.no_cuda and "
            "torch.cuda.is_available() are properly set if this is unexpected"
        )

    if cfg.compute.distributed and use_cuda:
        device = torch.device(f"cuda:{rank}")
    else:
        print(f"loading to device rank: {rank}")
        device = torch.device(f"cuda:{rank}")
    if not use_cuda:
        device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = get_dataloaders(cfg)

    model = ModelFactory.load_model(
        model=cfg.model.name, dataset=cfg.dataset.name
    )
    model.to(device)
    if cfg.compute.distributed:
        model = DistributedDataParallel(model, device_ids=[rank])
    if model_state is not None:
        try:
            model.load_state_dict(model_state)
        except RuntimeError:
            model_state = checkpoint.get_single_process_model_state_from_distributed_state()
            model.load_state_dict(model_state)
            
    optimizer = get_optimizer(cfg, model, state_dict=optimizer_state)
    scheduler = get_lr_scheduler(cfg, optimizer, state_dict=scheduler_state)
    pruner = None
    if cfg.rigl.dense_allocation is not None:
        T_end = get_T_end(cfg, [0 for _ in range(0,1251)])
        if cfg.rigl.const_fan_in:
            rigl_scheduler = RigLConstFanScheduler
        else:
            rigl_scheduler = RigLScheduler
        pruner = rigl_scheduler(
            model,
            optimizer,
            dense_allocation=cfg.rigl.dense_allocation,
            alpha=cfg.rigl.alpha,
            delta=cfg.rigl.delta,
            static_topo=cfg.rigl.static_topo,
            T_end=T_end,
            ignore_linear_layers=cfg.rigl.ignore_linear_layers,
            grad_accumulation_n=cfg.rigl.grad_accumulation_n,
            sparsity_distribution=cfg.rigl.sparsity_distribution,
            erk_power_scale=cfg.rigl.erk_power_scale,
            state_dict=pruner_state,
            filter_ablation_threshold=cfg.rigl.filter_ablation_threshold,
            static_ablation=cfg.rigl.static_ablation,
            dynamic_ablation=cfg.rigl.dynamic_ablation,
            min_salient_weights_per_neuron=cfg.rigl.min_salient_weights_per_neuron,  # noqa
            use_sparse_init=cfg.rigl.use_sparse_initialization,
            init_method_str=cfg.rigl.init_method_str,
            use_sparse_const_fan_in_for_ablation=cfg.rigl.use_sparse_const_fan_in_for_ablation,  # noqa
        )
        
        step=0
    return pruner, model, train_loader

In [3]:
def get_flops_df(model_name, dataset):
    df = {k:[] for k in ["rigl.dense_allocation", "flops", "model",]}
    for da in [0.01, 0.05, 0.0625, 0.1, 0.2, 0.25,]:
        print(f"Calculating with dense_alloc == {da}")
        pruner, model, train_loader = get_pruner_model_loader(da, model_name, dataset)
        model.eval()
        for data, _ in train_loader:
            data = data[0].to("cpu").reshape(1, *data[0].shape)
            break
        
        flops = FlopCountAnalysis(model.to("cpu"),data)
        total_flops = 0
        S = pruner.S
        names, W = get_names_and_W(model)
        for name, counter in flops.by_module_and_operator().items():
            if name in names:
                if len(counter) != 1:
                    raise ValueError(f"Too many items found in {name}. Goodbye")
                f = list(counter.values())[0]
                s = S[names.index(name)]
                if s is None:
                    s=1
                total_flops += f*(1-s)
        del model
        del pruner
        del train_loader
        df["rigl.dense_allocation"].append(da)
        df["flops"].append(total_flops)
        df["model"].append(model_name)
    return pd.DataFrame(df)

# df = get_flops_df("resnet50", "imagenet")

In [16]:
p, m, l = get_pruner_model_loader("null", "resnet50", "imagenet")

Global seed set to 42


{'no_cuda': False, 'cuda_kwargs': {'num_workers': '${ oc.decode:${oc.env:NUM_WORKERS} }', 'pin_memory': True}, 'distributed': False, 'world_size': 4, 'dist_backend': 'nccl'}
loading to device rank: 0


INFO:/home/mike/condensed-sparsity/src/rigl_torch/models/model_factory.py:Loading model resnet50/imagenet using <function get_imagenet_resnet50 at 0x7f8ac4eb2200> with args: () and kwargs: {}


In [30]:
input=torch.ones(size=(1,3,224,224))
input.shape[-1]

224

In [5]:
from micronet_challenge.counting import *
import torch.nn as nn
Conv2D._fields
 
def get_conv_op(conv: nn.Conv2d, input):
    use_bias = True if conv.bias is not None else False
    c_out, c_in, k_x, k_y = conv.weight.shape
    input=torch.ones(size=(1,3,224,224))
    return Conv2D(
        input_size=input.shape[-1],
        kernel_shape=(k_x, k_y, c_in, c_out),
        strides=conv.stride,
        use_bias=use_bias,
        padding="valid",
        activation="relu",
    )

def get_linear_op(linear: nn.Linear, input, use_relu_activation: bool = True):
    c_out, c_in = m._modules['fc'].weight.shape
    return FullyConnected(
        kernel_shape=(c_in, c_out),
        use_bias = True if linear.bias is not None else False,
        activation="relu" if use_relu_activation else None,
    )
    

In [6]:
m._modules['fc'].weight.shape[0]

1000

In [4]:
def get_flops_df(model_name, dataset):
    df = {k:[] for k in ["rigl.dense_allocation", "flops", "model",]}
    # for da in [0.01, 0.05, 0.0625, 0.1, 0.2, 0.25,]:
    for da in ["null",]:
        print(f"Calculating with dense_alloc == {da}")
        pruner, model, train_loader = get_pruner_model_loader(da, model_name, dataset)
        model.eval()
        model.to("cpu")
        for data, _ in train_loader:
            data = data[0].to("cpu").reshape(1, *data[0].shape)
            break
        
        flops = FlopCountAnalysis(model.to("cpu"), data)
        return flops
    
flops = get_flops_df("resnet50", "imagenet")
flops

Calculating with dense_alloc == null


Global seed set to 42


{'no_cuda': False, 'cuda_kwargs': {'num_workers': '${ oc.decode:${oc.env:NUM_WORKERS} }', 'pin_memory': True}, 'distributed': False, 'world_size': 4, 'dist_backend': 'nccl'}
loading to device rank: 0


INFO:/home/mike/condensed-sparsity/src/rigl_torch/models/model_factory.py:Loading model resnet50/imagenet using <function get_imagenet_resnet50 at 0x7ff6dde82200> with args: () and kwargs: {}


<fvcore.nn.flop_count.FlopCountAnalysis at 0x7ff6dc45b3a0>

In [5]:
flops.total()/1e9



4.111512576

In [25]:
total_flops

1896758355328.0

In [27]:
from typing import List, Optional, NamedTuple

def get_op_from_module(m, input):
    if isinstance(m, nn.Conv2d):
        return get_conv_op(m, input)
    if isinstance(m, nn.Linear):
        return get_linear_op(m, input, use_relu_activation=False)  # Only 1 layer

def get_names_and_ops(
    module,
    input: torch.Tensor,
    target_names: Optional[List[str]]=None,
) -> Dict[str, nn.Module]:
    if target_names is None:
        target_names, _ = get_names_and_W(module)
    names_ops = {k: None for k in target_names}
    
    for n,m in module.named_modules():
        if n in target_names:
            op = get_op_from_module(m, input)
            names_ops[n]=op
    return names_ops


def get_model_info(m, p):
    names = get_names_and_ops(m, input=torch.ones(size=(1,3,224,224)))
    # names

    total_flops = 0
    flops_dict = {n:0 for n in names}
    if p is not None:
        S = p.S
    else:
        S = [0. for _ in range(len(names))]
    total_flops = 0
    total_param_bits = 0
    total_params = 0.
    n_zeros = 0.
    for s, (n, o) in list(zip(S, names.items())):
        param_count, n_mults, n_adds = count_ops(o, s, param_bits=32)
        print(f"{n}: FLOPS: {(n_mults+n_adds)/1e9}")
        k_shape = o.kernel_shape
        total_param_bits += param_count
        total_flops += n_mults + n_adds
        n_param = np.prod(k_shape)
        total_params += n_param
        n_zeros += int(n_param * s)
    return total_flops, total_param_bits, n_zeros / total_params

total_flops, params, global_sparsity = get_model_info(m, p)

conv1: FLOPS: 0.223552896
layer1.0.conv1: FLOPS: 0.411041792
layer1.0.conv2: FLOPS: 3.633610752
layer1.0.conv3: FLOPS: 1.644167168
layer1.0.downsample.0: FLOPS: 1.644167168
layer1.1.conv1: FLOPS: 1.644167168
layer1.1.conv2: FLOPS: 3.633610752
layer1.1.conv3: FLOPS: 1.644167168
layer1.2.conv1: FLOPS: 1.644167168
layer1.2.conv2: FLOPS: 3.633610752
layer1.2.conv3: FLOPS: 1.644167168
layer2.0.conv1: FLOPS: 3.288334336
layer2.0.conv2: FLOPS: 3.633610752
layer2.0.conv3: FLOPS: 6.576668672
layer2.0.downsample.0: FLOPS: 3.288334336
layer2.1.conv1: FLOPS: 6.576668672
layer2.1.conv2: FLOPS: 14.534443008
layer2.1.conv3: FLOPS: 6.576668672
layer2.2.conv1: FLOPS: 6.576668672
layer2.2.conv2: FLOPS: 14.534443008
layer2.2.conv3: FLOPS: 6.576668672
layer2.3.conv1: FLOPS: 6.576668672
layer2.3.conv2: FLOPS: 14.534443008
layer2.3.conv3: FLOPS: 6.576668672
layer3.0.conv1: FLOPS: 13.153337344
layer3.0.conv2: FLOPS: 14.534443008
layer3.0.conv3: FLOPS: 26.306674688
layer3.0.downsample.0: FLOPS: 13.153337344
l

In [18]:
n_o = get_names_and_ops(m, torch.ones(size=(1,3,224,224)))

In [31]:
total_flops/1e9

1896.758355328

In [20]:
n_o

{'conv1': Conv2D(input_size=224, kernel_shape=(7, 7, 3, 64), strides=(2, 2), padding='valid', use_bias=False, activation='relu'),
 'layer1.0.conv1': Conv2D(input_size=224, kernel_shape=(1, 1, 64, 64), strides=(1, 1), padding='valid', use_bias=False, activation='relu'),
 'layer1.0.conv2': Conv2D(input_size=224, kernel_shape=(3, 3, 64, 64), strides=(1, 1), padding='valid', use_bias=False, activation='relu'),
 'layer1.0.conv3': Conv2D(input_size=224, kernel_shape=(1, 1, 64, 256), strides=(1, 1), padding='valid', use_bias=False, activation='relu'),
 'layer1.0.downsample.0': Conv2D(input_size=224, kernel_shape=(1, 1, 64, 256), strides=(1, 1), padding='valid', use_bias=False, activation='relu'),
 'layer1.1.conv1': Conv2D(input_size=224, kernel_shape=(1, 1, 256, 64), strides=(1, 1), padding='valid', use_bias=False, activation='relu'),
 'layer1.1.conv2': Conv2D(input_size=224, kernel_shape=(3, 3, 64, 64), strides=(1, 1), padding='valid', use_bias=False, activation='relu'),
 'layer1.1.conv3': C

In [8]:
total_flops/1e9

1896.758355328

In [12]:
params/32

25503912.0

In [11]:
total=0
for p in m.parameters():
    total+=p.numel()
total

25557032

In [42]:
child.__dict__["_modules"]

OrderedDict()

In [88]:
w_idx = 0
total_flops = 0
modules_to_ignore = []
S = pruner.S
names, W = get_names_and_W(model)
for name, counter in flops.by_module_and_operator().items():
    if name in names:
        if len(counter) != 1:
            raise ValueError("?")
        f = list(counter.values())[0]
        s = S[names.index(name)]
        if s is None:
            s=1
        total_flops += f*(1-s)

total_flops

99517038.63962848

In [85]:
4089284608/1e9

4.089284608

4089284608/1e9

In [35]:
import pandas as pd
pd.DataFrame(params).to_dict()

{'dense_allocation': {0: 0.01,
  1: 0.05,
  2: 0.0625,
  3: 0.1,
  4: 0.2,
  5: 0.25,
  6: 0.3,
  7: 0.4,
  8: 0.5},
 'parameters': {0: 274072,
  1: 1289784,
  2: 1611784,
  3: 2571656,
  4: 5113920,
  5: 6391944,
  6: 7663032,
  7: 10222975,
  8: 12774296},
 'dense_params': {0: 25557032,
  1: 25557032,
  2: 25557032,
  3: 25557032,
  4: 25557032,
  5: 25557032,
  6: 25557032,
  7: 25557032,
  8: 25557032}}