In [1]:
# Render matplotlib figures inline in the notebook output
%matplotlib inline

# Automatically reload edited modules (e.g. the local ``models`` package)
# before each cell executes, so source changes are picked up without a restart
%load_ext autoreload
%autoreload 2

In [2]:
import torch


# Index of the CUDA device to run on; change this to select a different GPU.
gpu_index = 0

# Fail fast when no usable GPU is present: later cells move every model and
# input onto ``device``. Raising (rather than calling ``sys.exit()``, which
# needs ``sys`` to be imported first) stops "Run All" with a clear message.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available. Stopped.')

# Catch a bad ``gpu_index`` here instead of at the first ``.to(device)``.
num_visible_gpus = torch.cuda.device_count()
if not 0 <= gpu_index < num_visible_gpus:
    raise RuntimeError(
        'gpu_index={} is out of range: {} CUDA device(s) visible.'.format(
            gpu_index, num_visible_gpus))

device = torch.device("cuda:{}".format(gpu_index))
print('CUDA available. PyTorch version:', torch.__version__, ' Device:', device)

CUDA available. PyTorch version: 1.3.1  Device: cuda:0


In [3]:
from models.resnet import ResNet18, ResNet34, ResNet50, ResNet101
from models.resnet_bibd import BResNet18, BResNet34, BResNet50, BResNet101
import numpy as np
import torch


# NOTE(review): ``n`` is not used anywhere in this cell; kept in case a later
# cell relies on it.
n = 3

print('Building the models...')

# Model constructors to instantiate (each moved onto ``device``), paired with
# the label printed once it has been built. The 101-layer variants are
# currently disabled.
_model_factories = [
    (ResNet18, 'ResNet18'),
    (ResNet34, 'ResNet34'),
    (ResNet50, 'ResNet50'),
    # (ResNet101, 'ResNet101'),
    (BResNet18, 'BResNet18'),
    (BResNet34, 'BResNet34'),
    (BResNet50, 'BResNet50'),
    # (BResNet101, 'BResNet101'),
]

model_list = []
for _factory, _label in _model_factories:
    model_list.append(_factory().to(device))
    print('{} added.'.format(_label))

print('All models built and added to the list.')
print('Model list:')
for model in model_list:
    print('    {}'.format(model.name))


from thop import profile
import sys

# Make the sibling ``bibd`` directory importable. Guard the append so that
# re-running this cell does not keep adding duplicate entries to sys.path.
# NOTE(review): the path is relative to the kernel's working directory.
_BIBD_DIR = '../bibd'
if _BIBD_DIR not in sys.path:
    sys.path.append(_BIBD_DIR)
from bibd_layer import BibdLinear, RandomSparseLinear, generate_fake_bibd_mask, BibdConv2d, RandomSparseConv2d, bibd_sparsity


def count_model(model_to_count, x, y):
    '''
    thop counting hook for a BIBD-masked linear layer.

    Counts ``in_features`` multiplies and ``in_features - 1`` additions per
    output element, as for a dense linear layer. It then subtracts 2 ops (one
    multiply, one add) for every weight zeroed by
    ``generate_fake_bibd_mask(in_features, out_features)``, once per output row.

    NOTE(review): this hook is not registered in this notebook's ``custom_ops``,
    so thop does not currently call it.

    Args:
        model_to_count: the hooked module; its ``total_ops`` buffer is incremented.
        x: tuple of the module's forward inputs (thop passes inputs as a tuple);
           only ``x[0]`` is used. Assumes features are on the last dimension.
        y: the module's output tensor; assumes features are on the last dimension.
    '''
    # thop hands the inputs over as a tuple; len(x) would be the number of
    # inputs, not the feature count.
    x = x[0]
    in_features = x.size(-1)
    out_features = y.size(-1)

    print('in_features = {}, out_features = {}'.format(in_features, out_features))

    # per output element: in_features multiplies and in_features - 1 adds
    total_mul = in_features
    total_add = in_features - 1
    num_elements = y.numel()
    total_ops = (total_mul + total_add) * num_elements

    # Each zeroed weight saves one multiply and one add for every output row
    # (batch and any leading dims), not just once overall.
    num_rows = num_elements // out_features if out_features else 0
    num_zero_weights = in_features * out_features - int(np.sum(generate_fake_bibd_mask(in_features, out_features)))
    total_ops -= num_zero_weights * 2 * num_rows

    # DoubleTensor (as in count_bibdConv2d): float32 cannot represent large
    # op counts exactly.
    model_to_count.total_ops += torch.DoubleTensor([int(total_ops)])


def count_bibdConv2d(m, x: (torch.Tensor,), y: torch.Tensor):
    '''
    thop counting hook shared by BibdConv2d and RandomSparseConv2d.

    Applies thop's dense convolution formula,
    ``output elements x (input channels per group x kernel area + bias)``,
    and scales the result by ``bibd_sparsity(in_channels, out_channels)``.

    Reference: https://github.com/Lyken17/pytorch-OpCounter/blob/dbc052ec2d914e950b3e0fa24d3eea753a4e0bb7/thop/vision/basic_hooks.py#L15

    Args:
        m: the hooked module; reads ``fpWeight``, ``in_channels``,
           ``out_channels`` and ``conGroups``, and increments ``total_ops``.
        x: tuple of the module's forward inputs (not needed for the count).
        y: the module's output tensor.
    '''
    x = x[0]  # unwrap thop's input tuple; the count does not use it

    # Kernel area: product of every weight dim after the first two
    # (presumably kernel height x width).
    kernel_area = 1
    for dim in m.fpWeight.size()[2:]:
        kernel_area *= dim

    # The current BibdConv2d implementation never applies a bias, so there is
    # no extra per-output bias op (the dense formula would use 1 when m.bias is set).
    bias_term = 0
    ops_per_output = m.in_channels // m.conGroups * kernel_area + bias_term

    # Dense count over all outputs (N x Cout x H x W), scaled by the sparsity.
    # NOTE(review): assumes bibd_sparsity returns the fraction of weights kept.
    dense_ops = y.nelement() * ops_per_output
    sparse_ops = dense_ops * bibd_sparsity(m.in_channels, m.out_channels)

    m.total_ops += torch.DoubleTensor([int(sparse_ops)])


# thop counting rules for the project's sparse convolution layers; both layer
# types are counted by the same hook.
custom_ops = {
    layer_cls: count_bibdConv2d
    for layer_cls in (BibdConv2d, RandomSparseConv2d)
}


# Profile every model on one dummy 3x32x32 input batch. The values are random,
# so only the shape matters; the tensor is built once on ``device`` and reused.
# The name ``dummy_input`` avoids shadowing the builtin ``input``.
# NOTE(review): thop's built-in conv/linear hooks count multiply-accumulates
# (MACs); confirm the "FLOPs" label is what is intended.
input_shape = (1, 3, 32, 32)
dummy_input = torch.randn(input_shape, device=device)
for model in model_list:
    flops, params = profile(model, inputs=(dummy_input, ), custom_ops=custom_ops)
    print('Model: %s, Params(M): %.4f, FLOPs(M): %.2f' % (model.name, params / (1000 ** 2), flops / (1000 ** 2)))

Building the models...
ResNet18 added.
BResNet18 added.
BResNet101 added.
All models built and added to the list.
Model list:
    ResNet-18
    B-ResNet-18
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'models.resnet.BasicBlock'>. Treat it as zero Macs and zero Params.[00m
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[91m[WARN] Cannot find rule for <class 'models.resnet.ResNet'>. Treat it as zero Macs and zero Params.[00m
Model: ResNet-18, Params: 11.1740, FLOPs(M): 556.65
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batch