In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from thop import profile
from thop import clever_format

from Models.attention_unet import AttnUNet
from Models.unet import UNet

In [2]:
# Test the Attention U-Net's FLOPs and parameters
with torch.no_grad():
    # Run on the GPU when one is available, otherwise fall back to the CPU.
    # (The previous expression selected "cpu" in both branches, so the CUDA
    # check had no effect.)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A single random 3-channel 256x256 tensor serves as the dummy input for profile().
    # The name `input` shadows the Python builtin; it is kept so any other cell
    # that reads this notebook-level variable keeps working.
    input = torch.rand(1, 3, 256, 256).to(device)
    model = AttnUNet(input_channels=3, out_channels=1, channels=[64, 128, 256, 512]).to(device)
    # NOTE(review): thop's README names profile()'s first return value `macs`;
    # confirm the "Flops" label below matches the unit you intend to report.
    flops, params = profile(model, (input,))

    print("-" * 30)
    print(f'Flops  = {clever_format(flops, format="%.5f")}')
    print(f'Params = {clever_format(params, format="%.5f")}')

[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
------------------------------
Flops  = 50.43098G
Params = 8.82523M


In [3]:
# Profile the vanilla U-Net: report its operation count and parameter count
# for a single random 3-channel 256x256 input.
with torch.no_grad():
    # Prefer the GPU when CUDA is available, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    input = torch.rand(1, 3, 256, 256).to(device)
    model = UNet(3, 1, [64, 128, 256, 512]).to(device)
    # NOTE(review): thop's README names profile()'s first return value `macs`;
    # confirm the "Flops" label below matches the unit you intend to report.
    flops, params = profile(model, (input,))

    # Format both counts in one call, then emit the same three lines as before.
    flops_text, params_text = clever_format([flops, params], "%.5f")
    print(
        "-" * 30,
        f'Flops  = {flops_text}',
        f'Params = {params_text}',
        sep="\n",
    )

[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
------------------------------
Flops  = 49.80946G
Params = 8.56403M
