# Classification and Segmentation Models from Libraries
* Platform: Windows, Python, miniconda, PyTorch    
* Date: 2024-09-20    
* Author: Jing Zhang    


## models from TIMM
* Reference: https://huggingface.co/timm

list models

In [8]:
import timm

# Print the name of every architecture registered in timm, one per line.
for model_name in timm.list_models():
    print(model_name)

bat_resnext26ts
beit_base_patch16_224
beit_base_patch16_384
beit_large_patch16_224
beit_large_patch16_384
beit_large_patch16_512
beitv2_base_patch16_224
beitv2_large_patch16_224
botnet26t_256
botnet50ts_256
caformer_b36
caformer_m36
caformer_s18
caformer_s36
cait_m36_384
cait_m48_448
cait_s24_224
cait_s24_384
cait_s36_384
cait_xs24_384
cait_xxs24_224
cait_xxs24_384
cait_xxs36_224
cait_xxs36_384
coat_lite_medium
coat_lite_medium_384
coat_lite_mini
coat_lite_small
coat_lite_tiny
coat_mini
coat_small
coat_tiny
coatnet_0_224
coatnet_0_rw_224
coatnet_1_224
coatnet_1_rw_224
coatnet_2_224
coatnet_2_rw_224
coatnet_3_224
coatnet_3_rw_224
coatnet_4_224
coatnet_5_224
coatnet_bn_0_rw_224
coatnet_nano_cc_224
coatnet_nano_rw_224
coatnet_pico_rw_224
coatnet_rmlp_0_rw_224
coatnet_rmlp_1_rw2_224
coatnet_rmlp_1_rw_224
coatnet_rmlp_2_rw_224
coatnet_rmlp_2_rw_384
coatnet_rmlp_3_rw_224
coatnet_rmlp_nano_rw_224
coatnext_nano_rw_224
convformer_b36
convformer_m36
convformer_s18
convformer_s36
convit_base
conv

resnet

In [6]:
from torchinfo import summary
import timm

# ResNet-50 with random initialisation, a 10-way output, dropout 0.5 and
# global max pooling.
net = timm.create_model(
    'resnet50',
    global_pool='max',      # pooling max
    drop_rate=0.5,          # dropout 0.5
    in_chans=3,             # input channels
    num_classes=10,         # output classes 10
    pretrained=False,       # no pretrained weights
)

summary(net, input_size=(1, 3, 224, 224))  # batch size, channels, height, width

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 10]                   --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─Identity: 3-6                [1, 64, 56, 56]           --
│ 

fine-tune protocol

In [None]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim

# 1. load pretrained model with a 10-class classification head
model = timm.create_model('resnet50', pretrained=True, num_classes=10)

# 2. freeze every pretrained parameter
for param in model.parameters():
    param.requires_grad = False # will not update the weights of pretrained model

# 3. only train the last layer: unfreeze the classification head.
#    get_classifier() returns the head module (`fc` for timm ResNets), so this
#    also works for timm architectures whose head has a different name.
#    To additionally fine-tune the last residual stage, set requires_grad=True
#    on model.layer4.parameters() here, before the optimizer is built.
for param in model.get_classifier().parameters():
    param.requires_grad = True

# 4. loss function and optimizer
criterion = nn.CrossEntropyLoss()
# Give the optimizer exactly the parameters left trainable above, so the
# optimizer and the freezing logic cannot drift out of sync.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)

# 5. prepare dataset and dataloader

# 6. start training


## models from MONAI
* Reference: https://docs.monai.io/en/stable/networks.html

Basic unet

In [11]:
from torchinfo import summary
from monai.networks.nets import BasicUnet

# 1D BasicUNet mapping 2 input channels to 1 output channel,
# LeakyReLU activations and affine instance normalisation.
leaky_relu = ('LeakyReLU', {'inplace': True, 'negative_slope': 0.1})
instance_norm = ('instance', {'affine': True})

net = BasicUnet(
    spatial_dims=1,
    in_channels=2,
    out_channels=1,
    features=(16, 32, 64, 128, 256, 16),
    act=leaky_relu,
    norm=instance_norm,
    bias=True,
    dropout=0.0,
    upsample='deconv',
)

summary(net, input_size=(1, 2, 224))  # batch size, channels, sequence length

BasicUNet features: (16, 32, 64, 128, 256, 16).


Layer (type:depth-idx)                             Output Shape              Param #
BasicUNet                                          [1, 1, 224]               --
├─TwoConv: 1-1                                     [1, 16, 224]              --
│    └─Convolution: 2-1                            [1, 16, 224]              --
│    │    └─Conv1d: 3-1                            [1, 16, 224]              112
│    │    └─ADN: 3-2                               [1, 16, 224]              32
│    └─Convolution: 2-2                            [1, 16, 224]              --
│    │    └─Conv1d: 3-3                            [1, 16, 224]              784
│    │    └─ADN: 3-4                               [1, 16, 224]              32
├─Down: 1-2                                        [1, 32, 112]              --
│    └─MaxPool1d: 2-3                              [1, 16, 112]              --
│    └─TwoConv: 2-4                                [1, 32, 112]              --
│    │    └─Convolution: 3-5     

v-net

In [34]:
from monai.networks.nets import VNet
from torchinfo import summary


# 2D V-Net with ELU activations and 0.5 dropout in the down/up transitions.
net = VNet(
    spatial_dims=2,
    in_channels=2,
    out_channels=1,
    act=('elu', {'inplace': True}),
    dropout_prob=0.5,
    dropout_prob_down=0.5,
    dropout_prob_up=(0.5, 0.5),
    # Must match spatial_dims. dropout_dim=3 would build Dropout3d, which
    # reads a 4-D (N, C, H, W) feature map as an unbatched (C, D, H, W)
    # volume. Dropout would then zero whole batch samples instead of channels.
    dropout_dim=2,
    bias=False
   )

summary(net,(1,2,224,224)) # bs, in_channel, height, width

Layer (type:depth-idx)                             Output Shape              Param #
VNet                                               [1, 1, 224, 224]          --
├─InputTransition: 1-1                             [1, 16, 224, 224]         --
│    └─Convolution: 2-1                            [1, 16, 224, 224]         --
│    │    └─Conv2d: 3-1                            [1, 16, 224, 224]         800
│    │    └─ADN: 3-2                               [1, 16, 224, 224]         32
│    └─ELU: 2-2                                    [1, 16, 224, 224]         --
├─DownTransition: 1-2                              [1, 32, 112, 112]         --
│    └─Conv2d: 2-3                                 [1, 32, 112, 112]         2,048
│    └─BatchNorm2d: 2-4                            [1, 32, 112, 112]         64
│    └─ELU: 2-5                                    [1, 32, 112, 112]         --
│    └─Sequential: 2-6                             [1, 32, 112, 112]         --
│    │    └─LUConv: 3-3        

residual unet

In [1]:
from torchinfo import summary
from monai.networks.nets import UNet

# Residual UNet (1D): one entry in `channels` per level, a stride of 2
# between consecutive levels, and 4 residual units per block.
channels = (16, 32, 64, 128, 256)
net = UNet(
    spatial_dims=1,
    in_channels=2,
    out_channels=1,
    channels=channels,
    strides=(2,) * (len(channels) - 1),
    num_res_units=4,
)

summary(net, input_size=(1, 2, 224))  # batch size, channels, sequence length

Layer (type:depth-idx)                                                                     Output Shape              Param #
UNet                                                                                       [1, 1, 224]               --
├─Sequential: 1-1                                                                          [1, 1, 224]               --
│    └─ResidualUnit: 2-1                                                                   [1, 16, 112]              --
│    │    └─Conv1d: 3-1                                                                    [1, 16, 112]              112
│    │    └─Sequential: 3-2                                                                [1, 16, 112]              2,468
│    └─SkipConnection: 2-2                                                                 [1, 32, 112]              --
│    │    └─Sequential: 3-3                                                                [1, 16, 112]              1,086,806
│    └─Sequential: 2-3  

unet++

In [12]:
from torchinfo import summary
from monai.networks.nets import BasicUnetPlusPlus

# 1D UNet++ (deep supervision disabled); configuration kept in one dict.
unetpp_kwargs = dict(
    spatial_dims=1,
    in_channels=2,
    out_channels=1,
    features=(16, 32, 64, 128, 256, 16),
    deep_supervision=False,
    act=('LeakyReLU', {'inplace': True, 'negative_slope': 0.1}),
    norm=('instance', {'affine': True}),
    bias=True,
    dropout=0.0,
    upsample='deconv',
)
net = BasicUnetPlusPlus(**unetpp_kwargs)
summary(net, input_size=(1, 2, 224))  # batch size, channels, sequence length

BasicUNetPlusPlus features: (16, 32, 64, 128, 256, 16).


Layer (type:depth-idx)                             Output Shape              Param #
BasicUNetPlusPlus                                  [1, 1, 224]               --
├─TwoConv: 1-1                                     [1, 16, 224]              --
│    └─Convolution: 2-1                            [1, 16, 224]              --
│    │    └─Conv1d: 3-1                            [1, 16, 224]              112
│    │    └─ADN: 3-2                               [1, 16, 224]              32
│    └─Convolution: 2-2                            [1, 16, 224]              --
│    │    └─Conv1d: 3-3                            [1, 16, 224]              784
│    │    └─ADN: 3-4                               [1, 16, 224]              32
├─Down: 1-2                                        [1, 32, 112]              --
│    └─MaxPool1d: 2-3                              [1, 16, 112]              --
│    └─TwoConv: 2-4                                [1, 32, 112]              --
│    │    └─Convolution: 3-5     

attention unet

In [42]:
from torchinfo import summary
from monai.networks.nets import AttentionUnet

# 1D Attention UNet: five channel levels with stride 2 between them,
# kernel size 3 for both the convolutions and the up-convolutions.
net = AttentionUnet(
    spatial_dims=1,
    in_channels=2,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2,) * 4,
    kernel_size=3,
    up_kernel_size=3,
    dropout=0.0,
)

summary(net, input_size=(1, 2, 224))  # batch size, channels, sequence length

Layer (type:depth-idx)                                                                Output Shape              Param #
AttentionUnet                                                                         [1, 1, 224]               --
├─Sequential: 1-1                                                                     [1, 1, 224]               --
│    └─ConvBlock: 2-1                                                                 [1, 16, 224]              --
│    │    └─Sequential: 3-1                                                           [1, 16, 224]              960
│    └─AttentionLayer: 2-2                                                            [1, 16, 224]              --
│    │    └─Sequential: 3-2                                                           [1, 32, 112]              675,804
│    │    └─UpConv: 3-3                                                               [1, 16, 224]              1,584
│    │    └─AttentionBlock: 3-4                                   

swin unet

In [13]:
from monai.networks.nets import SwinUNETR
from torchinfo import summary

# 2D SwinUNETR for 2-channel 256x256 inputs.
# NOTE(review): recent MONAI releases deprecate `img_size` for SwinUNETR
# (input size is checked in forward instead) — confirm against the
# installed version and drop it when upgrading.
net = SwinUNETR(
    img_size=(256,256),
    in_channels=2,
    out_channels=1,
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    feature_size=24,
    norm_name='instance',
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    normalize=True,
    use_checkpoint=False,
    spatial_dims=2,
    downsample='merging',
    use_v2=False
    )
summary(net,(1,2,256,256)) # bs, in_channel, height, width

Layer (type:depth-idx)                                  Output Shape              Param #
SwinUNETR                                               [1, 1, 256, 256]          --
├─SwinTransformer: 1-1                                  [1, 24, 128, 128]         --
│    └─PatchEmbed: 2-1                                  [1, 24, 128, 128]         --
│    │    └─Conv2d: 3-1                                 [1, 24, 128, 128]         216
│    └─Dropout: 2-2                                     [1, 24, 128, 128]         --
│    └─ModuleList: 2-3                                  --                        --
│    │    └─BasicLayer: 3-2                             [1, 48, 64, 64]           20,262
│    └─ModuleList: 2-4                                  --                        --
│    │    └─BasicLayer: 3-3                             [1, 96, 32, 32]           77,388
│    └─ModuleList: 2-5                                  --                        --
│    │    └─BasicLayer: 3-4                        

transformer unet

In [19]:
from torchinfo import summary
from monai.networks.nets import UNETR

# 2D UNETR on 3-channel 256x256 images. Transformer encoder: hidden size 768,
# MLP dim 3072, 12 attention heads, convolutional patch projection.
image_size = (256, 256)
net = UNETR(
    in_channels=3,
    out_channels=1,
    img_size=image_size,
    spatial_dims=2,
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type='conv',
    norm_name='instance',
    conv_block=True,
    res_block=True,
    dropout_rate=0.0,
    qkv_bias=False,
    save_attn=False,
)

summary(net, input_size=(1, 3, *image_size))  # batch size, channels, height, width


Layer (type:depth-idx)                             Output Shape              Param #
UNETR                                              [1, 1, 256, 256]          --
├─ViT: 1-1                                         [1, 256, 768]             --
│    └─PatchEmbeddingBlock: 2-1                    [1, 256, 768]             196,608
│    │    └─Conv2d: 3-1                            [1, 768, 16, 16]          590,592
│    │    └─Dropout: 3-2                           [1, 256, 768]             --
│    └─ModuleList: 2-2                             --                        --
│    │    └─TransformerBlock: 3-3                  [1, 256, 768]             7,085,568
│    │    └─TransformerBlock: 3-4                  [1, 256, 768]             7,085,568
│    │    └─TransformerBlock: 3-5                  [1, 256, 768]             7,085,568
│    │    └─TransformerBlock: 3-6                  [1, 256, 768]             7,085,568
│    │    └─TransformerBlock: 3-7                  [1, 256, 768]             

DiffusionModelUNet

In [1]:
from monai.networks.nets import DiffusionModelUNet
from torchinfo import summary
import torch

# 3D diffusion UNet: 1-channel input/output, 32 channels at every level,
# one transformer layer in its attention blocks.
net = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(32, 32, 32, 32),
    transformer_num_layers= 1,
    )

x = torch.randn(1, 1, 128, 128, 128)  # bs, in_channel, 3 spatial dims (128^3 volume)
timesteps = torch.randint(1, 5, (1,))  # one random integer timestep in [1, 4] (high=5 is exclusive)
print(f'timesteps:{timesteps}')
# The network takes (image, timesteps), so pass real tensors via input_data.
summary(net, input_data=(x, timesteps))

timesteps:tensor([4])


Layer (type:depth-idx)                                  Output Shape              Param #
DiffusionModelUNet                                      [1, 1, 128, 128, 128]     --
├─Sequential: 1-1                                       [1, 128]                  --
│    └─Linear: 2-1                                      [1, 128]                  4,224
│    └─SiLU: 2-2                                        [1, 128]                  --
│    └─Linear: 2-3                                      [1, 128]                  16,512
├─Convolution: 1-2                                      [1, 32, 128, 128, 128]    --
│    └─Conv3d: 2-4                                      [1, 32, 128, 128, 128]    896
├─ModuleList: 1-3                                       --                        --
│    └─DownBlock: 2-5                                   [1, 32, 64, 64, 64]       --
│    │    └─ModuleList: 3-1                             --                        119,232
│    │    └─DiffusionUnetDownsample: 3-2       

SegResNet

In [38]:
from monai.networks.nets import SegResNet
from torchinfo import summary


# 2D SegResNet with group normalisation (8 groups).
# The legacy `norm_name` / `num_groups` arguments are intentionally omitted:
# per the MONAI SegResNet API they are a deprecated backward-compatibility
# path superseded by `norm`, and the values passed before ('' and 8) are the
# defaults, so the network is unchanged.
net = SegResNet(
    spatial_dims=2,
    init_filters=8,
    in_channels=3,
    out_channels=1,
    dropout_prob=None,
    act=('RELU', {'inplace': True}),
    norm=('GROUP', {'num_groups': 8}),
    use_conv_final=True,
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
    )

summary(net,(1,3,256,256))  # batch size, channels, height, width

Layer (type:depth-idx)                   Output Shape              Param #
SegResNet                                [1, 1, 256, 256]          --
├─Convolution: 1-1                       [1, 8, 256, 256]          --
│    └─Conv2d: 2-1                       [1, 8, 256, 256]          216
├─ModuleList: 1-2                        --                        --
│    └─Sequential: 2-2                   [1, 8, 256, 256]          --
│    │    └─Identity: 3-1                [1, 8, 256, 256]          --
│    │    └─ResBlock: 3-2                [1, 8, 256, 256]          1,184
│    └─Sequential: 2-3                   [1, 16, 128, 128]         --
│    │    └─Convolution: 3-3             [1, 16, 128, 128]         1,152
│    │    └─ResBlock: 3-4                [1, 16, 128, 128]         4,672
│    │    └─ResBlock: 3-5                [1, 16, 128, 128]         4,672
│    └─Sequential: 2-4                   [1, 32, 64, 64]           --
│    │    └─Convolution: 3-6             [1, 32, 64, 64]           4,608

SegResNetVAE

In [40]:
from torchinfo import summary
from monai.networks.nets import SegResNetVAE

# 2D SegResNet with a VAE branch (latent size 256, fixed std 0.3).
# The same image size is used for the model and for the summary input.
image_size = (256, 256)
net = SegResNetVAE(
    input_image_size=image_size,
    vae_estimate_std=False,
    vae_default_std=0.3,
    vae_nz=256,
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    init_filters=8,
    dropout_prob=None,
    act=('RELU', {'inplace': True}),
    norm=('GROUP', {'num_groups': 8}),
    use_conv_final=True,
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
)

summary(net, input_size=(1, 3, *image_size))  # batch size, channels, height, width


Layer (type:depth-idx)                   Output Shape              Param #
SegResNetVAE                             [1, 1, 256, 256]          3,160,871
├─Convolution: 1-1                       [1, 8, 256, 256]          --
│    └─Conv2d: 2-1                       [1, 8, 256, 256]          216
├─ModuleList: 1-2                        --                        --
│    └─Sequential: 2-2                   [1, 8, 256, 256]          --
│    │    └─Identity: 3-1                [1, 8, 256, 256]          --
│    │    └─ResBlock: 3-2                [1, 8, 256, 256]          1,184
│    └─Sequential: 2-3                   [1, 16, 128, 128]         --
│    │    └─Convolution: 3-3             [1, 16, 128, 128]         1,152
│    │    └─ResBlock: 3-4                [1, 16, 128, 128]         4,672
│    │    └─ResBlock: 3-5                [1, 16, 128, 128]         4,672
│    └─Sequential: 2-4                   [1, 32, 64, 64]           --
│    │    └─Convolution: 3-6             [1, 32, 64, 64]         

SegResNetDS

In [41]:
from torchinfo import summary
from monai.networks.nets import SegResNetDS


# 2D SegResNetDS: 32 initial filters, encoder blocks (1, 2, 2, 4),
# dsdepth=1 deep-supervision level, transposed-convolution upsampling.
net = SegResNetDS(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    init_filters=32,
    blocks_down=(1, 2, 2, 4),
    blocks_up=None,
    act='relu',
    norm='batch',
    dsdepth=1,
    preprocess=None,
    upsample_mode='deconv',
    resolution=None,
)

summary(net, input_size=(1, 3, 256, 256))  # batch size, channels, height, width


Layer (type:depth-idx)                             Output Shape              Param #
SegResNetDS                                        [1, 1, 256, 256]          --
├─SegResEncoder: 1-1                               [1, 32, 256, 256]         --
│    └─Conv2d: 2-1                                 [1, 32, 256, 256]         864
│    └─ModuleList: 2-2                             --                        --
│    │    └─ModuleDict: 3-1                        --                        36,992
│    │    └─ModuleDict: 3-2                        --                        221,696
│    │    └─ModuleDict: 3-3                        --                        885,760
│    │    └─ModuleDict: 3-4                        --                        4,722,688
├─ModuleList: 1-2                                  --                        --
│    └─ModuleDict: 2-3                             --                        --
│    │    └─UpSample: 3-5                          [1, 128, 64, 64]          294,912
│    │  

## models from Segmentation Models PyTorch 3D
* Reference: https://github.com/ZFTurbo/segmentation_models_pytorch_3d

In [8]:
import segmentation_models_pytorch_3d as smp
from segmentation_models_pytorch_3d.encoders import get_preprocessing_fn
from torchinfo import summary

# One name drives both the model and its preprocessing, so the two cannot
# disagree (previously the model used efficientnet-b0 while the preprocessing
# was fetched for resnet34).
encoder_name = "efficientnet-b0"    # choose encoder, e.g. resnet34

net = smp.Unet(
    encoder_name=encoder_name,
    in_channels=3,                  # model input channels (1 for gray-scale volumes, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)

# Preparing your data the same way as during weights pre-training may give you better results
preprocess_input = get_preprocessing_fn(encoder_name, pretrained='imagenet')
print(preprocess_input)
summary(net,(1,3,256,256,256)) # batch size, channels, height, width, depth


functools.partial(<function preprocess_input at 0x000001BD6982F740>, input_space='RGB', input_range=[0, 1], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


Layer (type:depth-idx)                                  Output Shape              Param #
Unet                                                    [1, 3, 256, 256, 256]     --
├─EfficientNetEncoder: 1-1                              [1, 3, 256, 256, 256]     412,160
│    └─Conv3dStaticSamePadding: 2-1                     [1, 32, 128, 128, 128]    2,592
│    │    └─ZeroPad2d: 3-1                              [1, 3, 257, 257, 257]     --
│    └─BatchNorm3d: 2-2                                 [1, 32, 128, 128, 128]    64
│    └─MemoryEfficientSwish: 2-3                        [1, 32, 128, 128, 128]    --
│    └─ModuleList: 2-4                                  --                        --
│    │    └─MBConvBlock: 3-2                            [1, 16, 128, 128, 128]    2,024
│    │    └─MBConvBlock: 3-3                            [1, 24, 64, 64, 64]       7,732
│    │    └─MBConvBlock: 3-4                            [1, 24, 64, 64, 64]       13,302
│    │    └─MBConvBlock: 3-5              

## models from Segmentation Models PyTorch 2D
* Reference: https://smp.readthedocs.io/en/latest/index.html

In [15]:
from torchinfo import summary
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# Other architectures: Unet Unet++ FPN PSPNet DeepLabV3 DeepLabV3+ Linknet MAnet PAN
encoder_name = "resnet34"        # e.g. mobilenet_v2 or efficientnet-b7
encoder_weights = "imagenet"     # pre-trained weights used to initialise the encoder

net = smp.Unet(
    encoder_name=encoder_name,
    encoder_weights=encoder_weights,
    in_channels=3,   # 1 for gray-scale images, 3 for RGB, etc.
    classes=3,       # number of classes in your dataset
)

# Normalise inputs the same way the encoder's pre-training data was normalised
preprocess_input = get_preprocessing_fn(encoder_name, pretrained=encoder_weights)
print(preprocess_input)
summary(net, input_size=(1, 3, 256, 256))  # batch size, channels, height, width


functools.partial(<function preprocess_input at 0x0000016E9B684040>, input_space='RGB', input_range=[0, 1], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


Layer (type:depth-idx)                        Output Shape              Param #
Unet                                          [1, 3, 256, 256]          --
├─ResNetEncoder: 1-1                          [1, 3, 256, 256]          --
│    └─Conv2d: 2-1                            [1, 64, 128, 128]         9,408
│    └─BatchNorm2d: 2-2                       [1, 64, 128, 128]         128
│    └─ReLU: 2-3                              [1, 64, 128, 128]         --
│    └─MaxPool2d: 2-4                         [1, 64, 64, 64]           --
│    └─Sequential: 2-5                        [1, 64, 64, 64]           --
│    │    └─BasicBlock: 3-1                   [1, 64, 64, 64]           73,984
│    │    └─BasicBlock: 3-2                   [1, 64, 64, 64]           73,984
│    │    └─BasicBlock: 3-3                   [1, 64, 64, 64]           73,984
│    └─Sequential: 2-6                        [1, 128, 32, 32]          --
│    │    └─BasicBlock: 3-4                   [1, 128, 32, 32]          230,144