-
Notifications
You must be signed in to change notification settings - Fork 9
/
counter_flops_params.py
52 lines (46 loc) · 1.74 KB
/
counter_flops_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/usr/bin/env python
# coding=utf-8
'''
Author: wjm
Date: 2021-03-20 14:44:14
LastEditTime: 2021-03-22 15:25:02
Description: file content
'''
import sys
# sys.path.insert(1, "/ghome/fuxy/DPFN-master/thop")
# sys.path.insert(1, "/ghome/fuxy/DPFN-master/ptflops")
# sys.path.insert(1, "/ghome/fuxy/DPFN-master/torchsummaryX")
from thop import profile
import importlib, torch
from utils.config import get_config
import math
#from ptflops import get_model_complexity_info
import time
if __name__ == "__main__":
model_name = 'unet_panv2' #1.414 0.086 #hmb 57.515606 2.155652
net_name = model_name.lower()
lib = importlib.import_module('model.' + net_name)
net = lib.Net
cfg = get_config('option.yml')
model = net(
num_channels=4,
base_filter=64,
args=cfg
)
input = torch.randn(1, 4, 32, 32)
input1 = torch.randn(1, 1, 128, 128)
input2 = torch.randn(1, 4, 128, 128)
# macs, params = get_model_complexity_info(model, ((4, 32, 32), (), (1, 128, 128)),
# as_strings=True,print_per_layer_stat=True, verbose=True)
# print('{:<30} {:<8}'.format('Computational complexity: ', macs))
# print('{:<30} {:<8}'.format('Number of parameters: ', params))
# import torchsummaryX
# torchsummaryX.summary(model, [input.cpu(), None, input1.cpu()])
# print("The torchsummary result")
# from torchsummary import summary
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# summary(model.cuda(), [(4, 32, 32), (), (1, 128, 128)])
#
print("The thop result")
flops, params = profile(model, inputs=(input, input2, input1))
print('flops:{:.6f}, params:{:.6f}'.format(flops/(1e9), params/(1e6)))