In [2]:
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torchsummary import summary
from torchvision.models import mobilenet_v3_large,mobilenet_v3_small
from utils import get_config
from modelv5 import MobileNetV3UNet as MobileNetV3UNetv5
from modelv6 import MobileNetV3UNet as MobileNetV3UNetv6
from modelv7 import MobileNetV3UNet as MobileNetV3UNetv7

In [2]:
# Sanity-check MobileNetV3UNet v5: build it for a 4-channel input, print its
# structure and layer summary, and confirm a forward pass runs on a 112x112 tensor.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible dummy input across re-runs

model = MobileNetV3UNetv5(in_channels=4, out_channels=1, config_name="large", backbone=True).to(device)
# Inspection only: in train mode, summary()'s forward pass and ours would overwrite
# every BatchNorm layer's running mean/var with statistics of random noise.
# eval() freezes them; no_grad() below avoids building an unused autograd graph.
# NOTE(review): the printed backbone stem is Conv2d(3, 16, ...) although
# in_channels=4 -- confirm in modelv5 how the 4th channel is consumed.
model.eval()

dummy_input = torch.randn(1, 4, 112, 112, device=device)
print(model)
with torch.no_grad():
    summary(model, input_size=(4, 112, 112), device=str(device))
    output = model(dummy_input)
print("Output shape:", output.shape)



MobileNetV3UNet(
  (encoder): MobileNetV3Encoder(
    (backbone_model): MobileNetV3(
      (features): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): InvertedResidual(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
              (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Conv2dNormActivation(
              (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            )
          )
        )
        

In [3]:
# Sanity-check MobileNetV3UNet v6 (pretrained backbone): build it for a 4-channel
# input, print its structure and layer summary, and confirm a forward pass runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible dummy input across re-runs

model = MobileNetV3UNetv6(in_channels=4, out_channels=1, backbone_pretrained=True).to(device)
# Inspection only: in train mode, the forward passes below would overwrite the
# pretrained BatchNorm running mean/var with statistics of random noise.
# eval() freezes them; no_grad() avoids building an unused autograd graph.
model.eval()

dummy_input = torch.randn(1, 4, 112, 112, device=device)
print(model)
with torch.no_grad():
    summary(model, input_size=(4, 112, 112), device=str(device))
    output = model(dummy_input)
print("Output shape:", output.shape)

MobileNetV3UNet(
  (mask_adapter): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (initial_image_conv): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (2): Hardswish()
  )
  (encoder): MobileNetV3Encoder(
    (backbone_model): MobileNetV3(
      (features): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): InvertedResidual(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(16, 16, kernel_size=(3, 

In [3]:
# Sanity-check MobileNetV3UNet v7: build it for a 4-channel input, print its
# structure and layer summary, and confirm a forward pass runs on a 112x112 tensor.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible dummy input across re-runs

model = MobileNetV3UNetv7(in_channels=4, out_channels=1, config_name="large", backbone=True).to(device)
# Inspection only: in train mode, summary()'s forward pass and ours would overwrite
# every BatchNorm layer's running mean/var with statistics of random noise.
# eval() freezes them; no_grad() below avoids building an unused autograd graph.
model.eval()

dummy_input = torch.randn(1, 4, 112, 112, device=device)
print(model)
with torch.no_grad():
    summary(model, input_size=(4, 112, 112), device=str(device))
    output = model(dummy_input)
print("Output shape:", output.shape)



MobileNetV3UNet(
  (mask_adapter): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (initial_image_conv): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (2): Hardswish()
  )
  (encoder): MobileNetV3Encoder(
    (backbone_model): MobileNetV3(
      (features): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): InvertedResidual(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(16, 16, kernel_size=(3, 