In [1]:
import timm 
import random
from typing import Optional, Union, List, Tuple
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.models as models

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
    SegmentationModel,
    SegmentationHead,
    ClassificationHead,
)
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder

In [2]:
import segmentation_models_pytorch as smp

In [3]:
encoder = smp.encoders.TimmUniversalEncoder("efficientnetv2_s", in_channels=1, depth=5, pretrained=False)

In [4]:
# 여기서 반토막 짜리 들어간다. 그래서 64 64 이다. 
x = torch.randn(1, 1, 64, 64)
out = encoder(x)
for feature in out:
    print(feature.shape)

torch.Size([1, 1, 64, 64])
torch.Size([1, 24, 32, 32])
torch.Size([1, 48, 16, 16])
torch.Size([1, 64, 8, 8])
torch.Size([1, 160, 4, 4])
torch.Size([1, 256, 2, 2])


In [12]:
timm.list_models()

['adv_inception_v3',
 'bat_resnext26ts',
 'botnet26t_256',
 'botnet50ts_256',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'cspdarknet53',
 'cspdarknet53_iabn',
 'cspresnet50',
 'cspresnet50d',
 'cspresnet50w',
 'cspresnext50',
 'cspresnext50_iabn',
 'darknet53',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 'densenet121',
 'densenet121d',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenet264',
 'densenet264d_iabn',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2n

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [29]:
encoder = get_encoder(
            "resnet34",
            in_channels=1,
            depth=5,
            weights=None,
        )
x = torch.randn(1, 1, 128, 128)
out = encoder(x)
for feature in out:
    print(feature.shape)

torch.Size([1, 1, 128, 128])
torch.Size([1, 64, 64, 64])
torch.Size([1, 64, 32, 32])
torch.Size([1, 128, 16, 16])
torch.Size([1, 256, 8, 8])
torch.Size([1, 512, 4, 4])


In [30]:
encoder.out_channels

(1, 64, 64, 128, 256, 512)

In [5]:



class CustomUnet(smp.Unet):
    

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        if encoder_name == 'tf_efficientnetv2_b0':
            self.encoder = smp.encoders.TimmUniversalEncoder(encoder_name, in_channels=in_channels, depth=encoder_depth, pretrained=encoder_weights is not None)
            encoder_channels = [1, 24, 48, 64, 160, 256]
            
        else:  
            self.encoder = get_encoder(
                encoder_name,
                in_channels=in_channels,
                depth=encoder_depth,
                weights=encoder_weights,
            )
            encoder_channels = self.encoder.out_channels
        
        

        self.decoder = UnetDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=True if encoder_name.startswith("vgg") else False,
            attention_type=decoder_attention_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params)
        else:
            self.classification_head = None

        self.name = "u-{}".format(encoder_name)
        self.initialize()

model = CustomUnet(encoder_name="tf_efficientnetv2_b0", encoder_weights=None, decoder_channels=[256, 128, 64, 32, 16], in_channels=1)
# model = CustomUnet(encoder_name="resnet18", encoder_weights=None, in_channels=1)

out = model(torch.randn(1, 1, 128, 128))
out.shape

RuntimeError: Given groups=1, weight of size [256, 416, 3, 3], expected input[1, 304, 8, 8] to have 416 channels, but got 304 channels instead

In [19]:
model

CustomUnet(
  (encoder): TimmUniversalEncoder(
    (model): EfficientNetFeatures(
      (conv_stem): Conv2dSame(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (act1): SiLU(inplace=True)
      (blocks): Sequential(
        (0): Sequential(
          (0): ConvBnAct(
            (conv): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
          )
        )
        (1): Sequential(
          (0): EdgeResidual(
            (conv_exp): Conv2dSame(16, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
            (bn1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (se): Identity()
            (conv_pwl): Conv2d(64, 32, kernel_size=(1, 1), s

In [26]:
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import Conv2dStaticSamePadding
from timm.models.layers import Conv2dSame

def get_same_padding(x: int, k: int, s: int, d: int):
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)

def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1, 1), value: float = 0):
    ip, ih, iw = x.size()[-3:]
    pad_p, pad_h, pad_w = get_same_padding(ip, k[0], s[0], d[0]), get_same_padding(ih, k[1], s[1], d[1]), get_same_padding(iw, k[2], s[2], d[2])
    if pad_p > 0 or pad_h > 0 or pad_w > 0:
        x = F.pad(x, [pad_p // 2, pad_p - pad_p // 2, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
    return x

def conv3d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1, 1),
        padding: Tuple[int, int] = (0, 0, 0), dilation: Tuple[int, int] = (1, 1, 1), groups: int = 1):
    x = pad_same(x, weight.shape[-3:], stride, dilation)
    return F.conv3d(x, weight, bias, stride, (0, 0, 0), dilation, groups)


class Conv3dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv3dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv3d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


def convert_3d(module):

    new_module = module
    if isinstance(module, Conv2dStaticSamePadding):
        # print(module.static_padding)
        if isinstance(module.static_padding, nn.ZeroPad2d):
            padding = module.static_padding.padding + module.static_padding.padding[:2]
            # print(padding)
        else:
            padding = 0
            
        new_module = nn.Sequential(
                    nn.ConstantPad3d(padding, 0.),
                    nn.Conv3d(module.in_channels,
                               module.out_channels,
                               kernel_size=module.kernel_size[0],
                               stride=module.stride[0],
                               padding=0,
                               bias=module.bias is not None)
        )
    elif isinstance(module, Conv2dSame):
        new_module = Conv3dSame(module.in_channels,
                               module.out_channels,
                               kernel_size=module.kernel_size[0],
                               stride=module.stride[0],
                               padding=0,
                               bias=module.bias is not None
                               )
        
    elif isinstance(module, nn.Conv2d):
        # print(module.kernel_size)
        # print(module.bias)
        # print(module.kernel_size[0], module.padding[0])
        new_module = nn.Conv3d(module.in_channels,
                               module.out_channels,
                               kernel_size=module.kernel_size[0],
                               stride=module.stride[0],
                               padding=module.padding[0],
                               bias=module.bias is not None)
    elif isinstance(module, nn.BatchNorm2d):
        new_module = nn.BatchNorm3d(module.num_features,
                                    module.eps,
                                    module.affine,
                                    module.track_running_stats)
    elif isinstance(module, nn.MaxPool2d):
        new_module = nn.MaxPool3d(kernel_size=module.kernel_size,
                                  stride=module.stride,
                                  padding=module.padding,
                                  dilation=module.dilation,
                                  ceil_mode=module.ceil_mode)
    elif isinstance(module, nn.AdaptiveAvgPool2d):
        new_module = nn.AdaptiveAvgPool3d(module.output_size)

    for name, child_module in new_module.named_children():
        setattr(new_module, name, convert_3d(child_module))

    return new_module


model_3d = convert_3d(model)
total_params = sum(p.numel() for p in model_3d.parameters() if p.requires_grad)
total_params

342070105

In [27]:
x = torch.randn(1, 1, 128, 128, 128)
out = model_3d(x)
out.shape

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 6 but got size 7 for tensor number 1 in the list.