<a href="https://colab.research.google.com/github/majauhar/DL_MVA/blob/main/SISR_Efficiency.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# For local environments: editable install of my fvcore fork (it fixes a bug that breaks Omni-SR).
# Note: on Colab, this local editable install leaves fvcore's submodules unavailable.
# !git clone https://github.com/majauhar/fvcore.git
# !pip install -e fvcore
# !pip install einops

In [2]:
"""
For Colab environments:
Wouldn't work for OMNI-SR because of a bug in the original package
Which I have fixed in my fork.
"""
# %pip install fvcore -q
# %pip install einops -q

"\nFor Colab environments:\nWouldn't work for OMNI-SR because of a bug in the original package\nWhich I have fixed in my fork.\n"

In [3]:
# !git clone https://github.com/majauhar/DL_MVA.git
# cd DL_MVA/
# !git clone https://github.com/hellloxiaotian/LESRCNN.git

In [1]:
import torch
import numpy as np
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_str
from fvcore.nn import flop_count_table
from time import perf_counter

In [2]:
# Local imports
from utils.efficiency_results import get_model_flops, get_model_activation
from lesrcnn.model import Net
from omni.model import OmniSR

In [3]:
def forward_inference(model, input):
    """
        Time a single forward pass of `model` on `input`.

        This measures ONE inference; averaging over many runs is the
        caller's responsibility (call this function in a loop).

        Gradient tracking is disabled for the pass so that autograd
        bookkeeping is not counted as inference time. When `input` is a
        CUDA tensor, the device is synchronized before and after the pass,
        because CUDA kernels launch asynchronously and `perf_counter` would
        otherwise only capture launch overhead.

        Args:
            model: callable (typically a torch.nn.Module) applied to `input`.
            input: tensor fed to `model` (the name shadows the builtin
                `input` but is kept for backward compatibility).

        Returns:
            float: wall-clock duration of the forward pass, in seconds.
    """
    on_cuda = isinstance(input, torch.Tensor) and input.is_cuda
    with torch.no_grad():
        if on_cuda:
            # Drain previously queued GPU work so it is not billed to this pass.
            torch.cuda.synchronize()
        start_time = perf_counter()
        _ = model(input)
        if on_cuda:
            # Wait for this pass's kernels to actually finish before stopping the clock.
            torch.cuda.synchronize()
        end_time = perf_counter()

    return end_time - start_time

In [4]:
# Model under study; swap the comments to benchmark LESRCNN instead.
# NOTE(review): some saved outputs further down (the per-module dump that
# starts with `Net(`) were produced with LESRCNN — re-run after switching models.
# model = Net() # LESRCNN 
# `.eval()` puts any train/inference-dependent layers (dropout, batch-norm, if
# present) into inference behaviour for benchmarking; it returns the module itself.
model = OmniSR().eval() # Omni-SR network

window_size: 8
with_pe True
ffn_bias: 1
window_size: 8
with_pe True
ffn_bias: 1
window_size: 8
with_pe True
ffn_bias: 1
window_size: 8
with_pe True
ffn_bias: 1
window_size: 8
with_pe True
ffn_bias: 1


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
# Benchmark settings.
N_WARMUP = 10          # untimed passes to absorb one-off first-call costs
N_RUNS = 100           # timed passes averaged into the reported figure
BENCH_SHAPE = (1, 3, 256, 256)

# Warm-up: the first forward passes typically include lazy initialisation and
# allocator growth, which would otherwise inflate the mean.
with torch.no_grad():
    for _ in range(N_WARMUP):
        model(torch.randn(*BENCH_SHAPE))

# A fresh random input per run; tensor creation happens outside the timed region.
deltas = [forward_inference(model, torch.randn(*BENCH_SHAPE)) for _ in range(N_RUNS)]

average_time = np.mean(deltas)
std_time = np.std(deltas)
print("inference time: {:.2f} ± {:.2f} ms".format(average_time * 1e3, std_time * 1e3))

In [9]:
"""
To find the number of activations.
Model summary tools based on NTIRE challenge on efficient super-resolution: https://cvlai.net/ntire/2023/
"""
input_dim = (3, 256, 256)
activations, num_conv = get_model_activation(model, input_dim)
activations = activations / 10 ** 6
print("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
print("{:>16s} : {:<d}".format("#Conv2d", num_conv))


flops = get_model_flops(model, input_dim, False)
flops = flops / 10 ** 9
print("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
num_parameters = num_parameters / 10 ** 6
print("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))

    #Activations : 173.7359 [M]
         #Conv2d : 26
           FLOPs : 80.1813 [G]
         #Params : 0.6263 [M]


In [10]:
"""
Alternative tool for calculating FLOPs: Fvcore by Facebook research
https://github.com/facebookresearch/fvcore
"""

input = torch.randn(1, 3, 256, 256)
flops = FlopCountAnalysis(model, input)
print("FLOPs: {:.2f} [G]".format(flops.total() / 1e9))

Unsupported operator aten::add encountered 9 time(s)
Unsupported operator aten::pixel_shuffle encountered 2 time(s)


FLOPs: {.3f} [G] 80.026075136


In [11]:
# Layer-wise statistics: parameter counts/shapes and FLOPs per module,
# rendered by fvcore as a text table
print(flop_count_table(flops=flops))

| module                      | #parameters or shape   | #flops   |
|:----------------------------|:-----------------------|:---------|
| model                       | 0.626M                 | 80.026G  |
|  sub_mean.shifter           |  12                    |  0.59M   |
|   sub_mean.shifter.weight   |   (3, 3, 1, 1)         |          |
|   sub_mean.shifter.bias     |   (3,)                 |          |
|  add_mean.shifter           |  12                    |  2.359M  |
|   add_mean.shifter.weight   |   (3, 3, 1, 1)         |          |
|   add_mean.shifter.bias     |   (3,)                 |          |
|  conv1.0                    |  1.728K                |  0.113G  |
|   conv1.0.weight            |   (64, 3, 3, 3)        |          |
|  conv2.0                    |  36.864K               |  2.416G  |
|   conv2.0.weight            |   (64, 64, 3, 3)       |          |
|  conv3.0                    |  4.096K                |  0.268G  |
|   conv3.0.weight            |   (64, 64, 1, 1)

In [12]:
print(flop_count_str(flops=flops))  # module tree annotated with #params / #flops

Net(
  #params: 0.63M, #flops: 80.03G
  (sub_mean): MeanShift(
    #params: 12, #flops: 0.59M
    (shifter): Conv2d(
      3, 3, kernel_size=(1, 1), stride=(1, 1)
      #params: 12, #flops: 0.59M
    )
  )
  (add_mean): MeanShift(
    #params: 12, #flops: 2.36M
    (shifter): Conv2d(
      3, 3, kernel_size=(1, 1), stride=(1, 1)
      #params: 12, #flops: 2.36M
    )
  )
  (conv1): Sequential(
    #params: 1.73K, #flops: 0.11G
    (0): Conv2d(
      3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      #params: 1.73K, #flops: 0.11G
    )
  )
  (conv2): Sequential(
    #params: 36.86K, #flops: 2.42G
    (0): Conv2d(
      64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      #params: 36.86K, #flops: 2.42G
    )
    (1): ReLU(inplace=True)
  )
  (conv3): Sequential(
    #params: 4.1K, #flops: 0.27G
    (0): Conv2d(
      64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
      #params: 4.1K, #flops: 0.27G
    )
  )
  (conv4): Sequential(
    