In [6]:
# For inline plotting
%matplotlib inline

# For auto reloading
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import torch


# Pick the compute device once: the GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Using PyTorch version:', torch.__version__, ' Device:', device)

Using PyTorch version: 1.3.1  Device: cuda


In [9]:
from models.resnet import ResNet18, ResNet34, ResNet50, ResNet101
import numpy as np


# Instantiate each ResNet depth (shallowest first) and move it onto `device`.
model_list = [
    build_model().to(device)
    for build_model in (ResNet18, ResNet34, ResNet50, ResNet101)
]

print('Model list:')
for model in model_list:
    print('    {}'.format(model.name))

from thop import profile
import sys
sys.path.append('../bibd')
from bibd_layer import BibdLinear, RandomSparseLinear, generate_fake_bibd_mask


def count_model(model_to_count, x, y, mask_fn=None):
    """thop custom-op counter for sparse (BIBD / random-sparse) linear layers.

    Intended for ``profile(..., custom_ops={BibdLinear: count_model, ...})``.
    thop attaches it as a forward hook, so it is called as
    ``count_model(module, inputs, output)``: ``x`` is the *tuple* of forward
    inputs, not a tensor.

    The count starts from a dense linear layer, which needs ``in_features``
    multiplies and ``in_features - 1`` adds per output element. It then removes
    2 ops (one mul + one add) for every zero weight in the mask, once per input
    row.

    Args:
        model_to_count: the layer being profiled; must carry thop's
            ``total_ops`` buffer, which is incremented in place.
        x: tuple of forward inputs. ``x[0]`` is assumed to be shaped
            ``(..., in_features)`` — TODO confirm for BibdLinear.
        y: forward output, assumed shaped ``(..., out_features)``.
        mask_fn: optional callable ``(in_features, out_features) -> array``
            whose sum is the number of kept (non-zero) weights. Defaults to
            ``generate_fake_bibd_mask``. NOTE(review): RandomSparseLinear may
            use a different mask — pass its own generator if so.
    """
    if mask_fn is None:
        mask_fn = generate_fake_bibd_mask

    # `x` is the hook's tuple of inputs, so len(x) would be the number of
    # inputs (1) and len(y) the batch size. Read the feature sizes from the
    # tensors' last dimension instead.
    in_features = x[0].size(-1)
    out_features = y.size(-1)

    print('in_features = {}, out_features = {}'.format(in_features, out_features))

    # Dense cost per output element.
    total_mul = in_features
    total_add = in_features - 1
    num_elements = y.numel()
    total_ops = (total_mul + total_add) * num_elements

    # Each zero weight saves one mul + one add for every input row
    # (batch and any other leading dims), not just once per layer.
    num_rows = num_elements // out_features if out_features else 0
    num_zero_weights = in_features * out_features - int(np.sum(mask_fn(in_features, out_features)))
    total_ops -= num_zero_weights * 2 * num_rows

    model_to_count.total_ops += torch.Tensor([int(total_ops)])

# Profile every model on a single 1 x 3 x 32 x 32 dummy input
# (presumably CIFAR-resolution images — confirm against the training setup).
# The input's values do not affect the counts, so it is built once.
#
# This thop build counts multiply-accumulates (its log says "Treat it as zero
# Macs"), so the first value returned by `profile` is MACs, not FLOPs
# (FLOPs are roughly 2 x MACs).
#
# To count BIBD / random-sparse linear layers, pass
# custom_ops={BibdLinear: count_model, RandomSparseLinear: count_model}.
dummy_input = torch.randn(1, 3, 32, 32).to(device)
for model in model_list:
    macs, params = profile(model, inputs=(dummy_input,))
    print('Model: %s, Params(M): %.4f, MACs(M): %.2f' % (model.name, params / (1000 ** 2), macs / (1000 ** 2)))

Model list:
    ResNet-18
    ResNet-34
    ResNet-50
    ResNet-101
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'models.resnet.BasicBlock'>. Treat it as zero Macs and zero Params.[00m
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[91m[WARN] Cannot find rule for <class 'models.resnet.ResNet'>. Treat it as zero Macs and zero Params.[00m
Model: ResNet-18, Params: 11.1740, FLOPs(M): 556.65
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.