# Problem 2

In [None]:
import torch
import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet-style image classifier.

    The network is split into three stages that are exposed as attributes:

    * ``features``   - five convolutions (each followed by an in-place ReLU)
      interleaved with three 3x3 / stride-2 max-pooling layers.
    * ``avgpool``    - adaptive average pooling to a fixed 6x6 spatial grid.
    * ``classifier`` - three fully connected layers ending in ``num_classes``
      outputs.

    Args:
        num_classes: Number of outputs produced by the last linear layer.
    """

    def __init__(self, num_classes: int = 1000) -> None:
        super().__init__()

        def conv_relu(in_ch, out_ch, **conv_kwargs):
            # One convolution followed by its in-place activation.
            return [nn.Conv2d(in_ch, out_ch, **conv_kwargs), nn.ReLU(inplace=True)]

        def pool():
            # Every pooling stage in this network uses the same window.
            return nn.MaxPool2d(kernel_size=3, stride=2)

        # Arguments are evaluated left to right, so the layers are created in
        # exactly the order they appear in the resulting Sequential.
        self.features = nn.Sequential(
            *conv_relu(3, 64, kernel_size=11, stride=4),
            pool(),
            *conv_relu(64, 192, kernel_size=5, padding=2),
            pool(),
            *conv_relu(192, 384, kernel_size=3, padding=1),
            *conv_relu(384, 256, kernel_size=3, padding=1),
            *conv_relu(256, 256, kernel_size=3, padding=1),
            pool(),
        )

        # Fixes the spatial size fed to the classifier at 6x6.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        flattened = 256 * 6 * 6  # channels * height * width after avgpool
        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Linear(flattened, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three stages in order and return the classifier output.

        Everything after the batch dimension is flattened before the
        fully connected layers.
        """
        pooled = self.avgpool(self.features(x))
        flat = torch.flatten(pooled, start_dim=1)
        return self.classifier(flat)



def _count_trainable(module):
    """Return how many scalar parameters of ``module`` have ``requires_grad`` set."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Default construction, i.e. a 1000-way output layer.
model = AlexNet()

# Parameter totals of the convolutional stage and the fully connected stage.
conv_params = _count_trainable(model.features)
print('conv_params: ', conv_params)
linear_params = _count_trainable(model.classifier)
print('linear_params: ', linear_params)
print('conv_params+linear_params: ', conv_params + linear_params)

# Counted over the whole model as a cross-check against the sum printed above.
total_params = _count_trainable(model)
print('total_params: ', total_params)

# Share of each stage in the total, as a percentage.
print('conv_params: {}%'.format((conv_params/total_params)*100))
# The label previously ended in a stray ": " after the percent sign.
print('linear_params: {}%'.format((linear_params/total_params)*100))

# Count Multiplications and Additions
# Convolutions
def _as_pair(value):
    """Return ``value`` as a ``(height, width)`` tuple, duplicating a bare int."""
    if isinstance(value, int):
        return (value, value)
    return tuple(value)


def _spatial_output(layer, in_hw):
    """Return the ``(height, width)`` a conv/pool ``layer`` produces from ``in_hw``.

    Uses the floor-mode sliding-window formula
    ``(size + 2*padding - dilation*(kernel - 1) - 1) // stride + 1``
    independently on each axis. A pooling layer's ``ceil_mode`` is not
    taken into account.

    Raises:
        NotImplementedError: if the layer uses string padding
            (``'same'`` / ``'valid'``), which has no numeric value to plug in.
    """
    if isinstance(layer.padding, str):
        raise NotImplementedError(
            'string padding {!r} is not supported'.format(layer.padding)
        )
    kernel = _as_pair(layer.kernel_size)
    stride = _as_pair(layer.stride)
    padding = _as_pair(layer.padding)
    dilation = _as_pair(layer.dilation)
    return tuple(
        (size + 2 * p - d * (k - 1) - 1) // s + 1
        for size, k, s, p, d in zip(in_hw, kernel, stride, padding, dilation)
    )


def _count_conv_ops(layers, in_hw):
    """Count multiplications and additions of the ``nn.Conv2d`` layers in ``layers``.

    The spatial size is propagated through every ``nn.Conv2d`` and
    ``nn.MaxPool2d``; all other layer types are treated as shape-preserving.

    Args:
        layers: Iterable of modules applied in order.
        in_hw: ``(height, width)`` of the input to the first layer.

    Returns:
        ``(multiplications, additions)`` for a single input sample.
    """
    mults = 0
    adds = 0
    for layer in layers:
        if isinstance(layer, nn.Conv2d):
            out_h, out_w = _spatial_output(layer, in_hw)
            k_h, k_w = layer.kernel_size
            # Products feeding one output value; with groups, each output
            # channel only sees in_channels // groups input channels.
            taps = (layer.in_channels // layer.groups) * k_h * k_w
            outputs = layer.out_channels * out_h * out_w
            mults += taps * outputs
            # Summing `taps` products takes taps - 1 additions; the bias, when
            # present, adds one more, making additions equal multiplications.
            bias_add = 0 if layer.bias is None else 1
            adds += (taps - 1 + bias_add) * outputs
            in_hw = (out_h, out_w)
        elif isinstance(layer, nn.MaxPool2d):
            # Pooling contributes no multiply/add work here, only a new size.
            in_hw = _spatial_output(layer, in_hw)
    return mults, adds


features = model.features
input_size = 227  # side length of the square input image used for the count
multiplications_conv, additions_conv = _count_conv_ops(
    features, (input_size, input_size)
)
print('Conv Multiplications: ', multiplications_conv)
print('Conv Additions: ', additions_conv)

# Linear Layers
classifier = model.classifier
multiplications_lin = 0
additions_lin = 0
for layer in classifier:
    if isinstance(layer, nn.Linear):
        # One product per (input, output) pair.
        multiplications_lin += layer.in_features * layer.out_features
        # Each output sums in_features products (in_features - 1 additions)
        # and then adds the bias when the layer has one, so with a bias the
        # addition count equals the multiplication count.
        additions_per_output = layer.in_features - 1 + (0 if layer.bias is None else 1)
        additions_lin += additions_per_output * layer.out_features
print('Linear Multiplications: ', multiplications_lin)
print('Linear Additions: ', additions_lin)
# Share of the fully connected layers in the combined conv + linear totals.
print('Linear Multiplications Percentage: {}%'.format((multiplications_lin/(multiplications_conv + multiplications_lin))*100))
print('Linear Additions Percentage: {}%'.format((additions_lin/(additions_conv + additions_lin))*100))

conv_params:  2469696
linear_params:  58631144
conv_params+linear_params:  61100840
total_params:  61100840
conv_params: 4.042000077249347%
linear_params: 95.95799992275064%: 
Conv Multiplications:  655566528
Conv Additions:  655566528
Linear Multiplications:  58621952
Linear Additions:  58621952
Linear Multiplications Percentage: 8.208190644576066%
Linear Additions Percentage: 8.208190644576066%
