In [35]:
import numpy as np
import torch
from torch import nn
import pandas as pd

In [3]:
# AlexNet expressed as one flat nn.Sequential (one module per line) so that
# the network can be walked module by module with a plain `for` loop.
AlexNet = nn.Sequential(
    # Stage 1: 11x11 convolution with stride 4, then 3x3 / stride-2 max-pooling.
    nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 2: 5x5 convolution (padding keeps the spatial size), then pooling.
    nn.Conv2d(96, 256, kernel_size=5, padding=2),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 3: three back-to-back 3x3 convolutions, then a final pooling.
    nn.Conv2d(256, 384, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(384, 384, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(384, 256, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Classifier head: flatten the 256 x 6 x 6 feature map, two hidden
    # fully-connected layers with dropout, then 1000 output units.
    nn.Flatten(),
    nn.Linear(256 * 6 * 6, 4096),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 1000),
)

In [4]:
# Push one random 3x224x224 image through the network a module at a time,
# printing the shape of the activation each module produces.
X = torch.randn(1, 3, 224, 224)
for module in AlexNet:
    X = module(X)
    module_name = type(module).__name__
    print(module_name, 'output shape:\t', X.shape)

Conv2d output shape:	 torch.Size([1, 96, 55, 55])
ReLU output shape:	 torch.Size([1, 96, 55, 55])
MaxPool2d output shape:	 torch.Size([1, 96, 27, 27])
Conv2d output shape:	 torch.Size([1, 256, 27, 27])
ReLU output shape:	 torch.Size([1, 256, 27, 27])
MaxPool2d output shape:	 torch.Size([1, 256, 13, 13])
Conv2d output shape:	 torch.Size([1, 384, 13, 13])
ReLU output shape:	 torch.Size([1, 384, 13, 13])
Conv2d output shape:	 torch.Size([1, 384, 13, 13])
ReLU output shape:	 torch.Size([1, 384, 13, 13])
Conv2d output shape:	 torch.Size([1, 256, 13, 13])
ReLU output shape:	 torch.Size([1, 256, 13, 13])
MaxPool2d output shape:	 torch.Size([1, 256, 6, 6])
Flatten output shape:	 torch.Size([1, 9216])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 1000])

In [43]:
# Count the learnable parameters (weight entries + bias entries) of every
# Conv2d / Linear module in AlexNet and tabulate them per layer.
rec = []
num_sum = 0
for layer in AlexNet:
    # Only convolution and fully-connected modules are counted here.
    if not isinstance(layer, (nn.Linear, nn.Conv2d)):
        continue

    weight_dims = list(layer.weight.shape)
    bias_len = layer.bias.shape[0]
    num = np.prod(weight_dims) + bias_len
    num_sum += num

    # Human-readable formula, e.g. "96*3*11*11+96".
    expr = '*'.join(map(str, weight_dims)) + f'+{bias_len}'
    rec.append([type(layer).__name__, expr, num])

print(f'{num_sum} = {round(num_sum / 1000 / 1000, 1)}M')
pd.DataFrame(rec, columns=['name', 'expr', 'parameter nums'])

62378344 = 62.4M


Unnamed: 0,name,expr,parameter nums
0,Conv2d,96*3*11*11+96,34944
1,Conv2d,256*96*5*5+256,614656
2,Conv2d,384*256*3*3+384,885120
3,Conv2d,384*384*3*3+384,1327488
4,Conv2d,256*384*3*3+256,884992
5,Linear,4096*9216+4096,37752832
6,Linear,4096*4096+4096,16781312
7,Linear,1000*4096+1000,4097000


In [45]:
# Rough per-layer operation count for one forward pass of a single
# 3x224x224 image through AlexNet.
#
# Cost model used below (the batch dimension is ignored):
#   Conv2d    : (weight entries + bias entries) * output height * output width
#   ReLU      : one operation per output element
#   MaxPool2d : output elements * pooling-window area
#   Linear    : weight entries + bias entries
# Modules of any other type (Flatten, Dropout, ...) are counted as zero and
# left out of the table.
X = torch.randn(1, 3, 224, 224)
total_ops = 0  # deliberately not called `sum`, which would shadow the builtin
rec = []

# Only activation shapes are needed, so skip building the autograd graph.
with torch.no_grad():
    for layer in AlexNet:
        X = layer(X)
        l_name = layer.__class__.__name__
        num = 0
        expr = ''

        if isinstance(layer, nn.Conv2d):
            w_size = list(layer.weight.shape)
            # A layer built with bias=False has no bias tensor.
            b_size = 0 if layer.bias is None else layer.bias.shape[0]
            num = (np.prod(w_size) + b_size) * np.prod(X.shape[2:])
            expr = f"({'*'.join(str(e) for e in w_size)}+{b_size})*{X.shape[2]}*{X.shape[3]}"
        elif isinstance(layer, nn.ReLU):
            num = np.prod(X.shape[1:])
            expr = '*'.join(str(e) for e in X.shape[1:])
        elif isinstance(layer, nn.MaxPool2d):
            # kernel_size may be a single int or a (height, width) pair.
            k = layer.kernel_size
            k_h, k_w = (k, k) if isinstance(k, int) else k
            num = np.prod(X.shape[1:]) * k_h * k_w
            expr = f"{'*'.join(str(e) for e in X.shape[1:])}*{k_h}*{k_w}"
        elif isinstance(layer, nn.Linear):
            w_size = list(layer.weight.shape)
            b_size = 0 if layer.bias is None else layer.bias.shape[0]
            num = np.prod(w_size) + b_size
            expr = f"{'*'.join(str(e) for e in w_size)}+{b_size}"

        if num:
            rec.append([l_name, expr, num])
        total_ops += num

print(f'{total_ops} = {round(total_ops / 1000 / 1000 / 1000, 1)}G')
# NOTE(review): the third column holds operation counts, not parameter
# counts; the 'parameter nums' label is kept so the table header is unchanged.
pd.DataFrame(rec, columns=['name', 'expr', 'parameter nums'])

1137675816 = 1.1G


Unnamed: 0,name,expr,parameter nums
0,Conv2d,(96*3*11*11+96)*55*55,105705600
1,ReLU,96*55*55,290400
2,MaxPool2d,96*27*27*3*3,629856
3,Conv2d,(256*96*5*5+256)*27*27,448084224
4,ReLU,256*27*27,186624
5,MaxPool2d,256*13*13*3*3,389376
6,Conv2d,(384*256*3*3+384)*13*13,149585280
7,ReLU,384*13*13,64896
8,Conv2d,(384*384*3*3+384)*13*13,224345472
9,ReLU,384*13*13,64896
