In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from thop import profile
from thop import clever_format

# Add the parent directory of 'models' to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from models.attention_unet import AttnUNet
from models.unet import UNet

In [None]:
# Profile the Attention U-Net: count multiply-accumulate ops (MACs) and parameters with thop.
with torch.no_grad():
    # Use the GPU when one is available. The previous expression selected "cpu" in both branches.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Op counting depends only on the dummy input's shape, not its values.
    # NOTE(review): the vanilla U-Net cell profiles at 256x256, so these counts are not
    # directly comparable with it. Confirm that 192x192 is the intended resolution.
    dummy_input = torch.rand(1, 3, 192, 192).to(device)
    model = AttnUNet(input_channels=3, out_channels=1, channels=[64, 128, 256, 512]).to(device)
    # thop's `profile` returns MACs (multiply-accumulates), not FLOPs.
    macs, params = profile(model, (dummy_input,))

    print("-" * 30)
    print(f'MACs   = {clever_format(macs, format="%.5f")}')
    # Common convention: one multiply-accumulate counts as ~2 floating-point operations.
    print(f'FLOPs  ~ {clever_format(2 * macs, format="%.5f")} (2 x MACs)')
    print(f'Params = {clever_format(params, format="%.5f")}')

In [None]:
# Profile the vanilla U-Net: count multiply-accumulate ops (MACs) and parameters with thop.
with torch.no_grad():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Op counting depends only on the dummy input's shape, not its values.
    dummy_input = torch.rand(1, 3, 256, 256).to(device)
    model = UNet(3, 1, [64, 128, 256, 512]).to(device)
    # thop's `profile` returns MACs (multiply-accumulates), not FLOPs.
    macs, params = profile(model, (dummy_input,))

    print("-" * 30)
    print(f'MACs   = {clever_format(macs, format="%.5f")}')
    # Common convention: one multiply-accumulate counts as ~2 floating-point operations.
    print(f'FLOPs  ~ {clever_format(2 * macs, format="%.5f")} (2 x MACs)')
    print(f'Params = {clever_format(params, format="%.5f")}')